Simplify Tensor Parallel implementation with PyTorch TP (#34184)
* Simplify Tensor Parallel implementation with PyTorch TP
* Move tp_plan to config
* Lint
* Format and warning
* Disable copy-from check
* Conditionally get attr from config
* make fix-copies
* Move base_model_tp_plan to PretrainedConfig
* Move TP into from_pretrained
* Add device context for load
* Do not serialize
* Move _tp_plan setting to post_init
* Add has_tp_plan
* Add test_tp
* Add 'Multi-gpu inference' doc
* Add backward support for device type identification
* Auto-detect accelerator
* supports_tp_plan
* copyright year
* Fix copy
This commit is contained in:
parent 7df93d6ffb
commit 20142ab542
@@ -218,6 +218,8 @@
    title: CPU inference
  - local: perf_infer_gpu_one
    title: GPU inference
  - local: perf_infer_gpu_multi
    title: Multi-GPU inference
  title: Optimizing inference
- local: big_models
  title: Instantiate a big model
docs/source/en/perf_infer_gpu_multi.md (new file, 68 lines)
@@ -0,0 +1,68 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Multi-GPU inference

Built-in Tensor Parallelism (TP) is now available for certain models when using PyTorch. Tensor parallelism shards a model onto multiple GPUs, enabling larger model sizes and parallelizing computations such as matrix multiplication.

To enable tensor parallelism, pass the argument `tp_plan="auto"` to [`~AutoModelForCausalLM.from_pretrained`]:

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Initialize distributed
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)

# Retrieve tensor parallel model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    tp_plan="auto",
)

# Prepare input tokens
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Distributed run
outputs = model(inputs)
```

You can use `torchrun` to launch the above script with multiple processes, each mapping to a GPU:

```
torchrun --nproc-per-node 4 demo.py
```

PyTorch tensor parallelism is currently supported for the following models:
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)

You can request tensor parallel support for another model by opening a GitHub issue or pull request.

### Expected speedups

You can benefit from considerable speedups for inference, especially for inputs with large batch sizes or long sequences.

For a single forward pass on [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) with a sequence length of 512 and various batch sizes, the expected speedup is as follows:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct, seqlen = 512, python, w_ compile.png">
</div>
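As a rough way to reproduce such a measurement on your own hardware, the sketch below times a single forward pass under `tp_plan="auto"`. The script name `benchmark_tp.py`, the batch size, and the warmup/iteration counts are illustrative assumptions, not the setup used to produce the plot above.

```python
# benchmark_tp.py -- minimal latency sketch, reusing the distributed setup from the demo above.
import os
import time

import torch
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)

# Load the model with the built-in tensor parallel plan applied.
model = AutoModelForCausalLM.from_pretrained(model_id, tp_plan="auto")
model.eval()

batch_size, seq_len = 4, 512  # assumed shapes, adjust freely
inputs = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device=device)

with torch.no_grad():
    for _ in range(3):  # warmup
        model(inputs)
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    for _ in range(10):
        model(inputs)
    torch.cuda.synchronize(device)

if rank == 0:
    print(f"avg forward latency: {(time.perf_counter() - start) / 10 * 1000:.1f} ms")
```

Launch it the same way as the demo script, for example `torchrun --nproc-per-node 4 benchmark_tp.py`.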
@@ -53,7 +53,7 @@ sections we go through the steps to run inference on CPU and single/multi-GPU se

* [Inference on a single CPU](perf_infer_cpu)
* [Inference on a single GPU](perf_infer_gpu_one)
* [Multi-GPU inference](perf_infer_gpu_one)
* [Multi-GPU inference](perf_infer_gpu_multi)
* [XLA Integration for TensorFlow Models](tf_xla)
@@ -71,6 +71,8 @@ class PretrainedConfig(PushToHubMixin):
          outputs of the model during inference.
    - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
      naming of attributes.
    - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-module FQNs of a base model to a tensor
      parallel plan applied to the sub-module when `model.tensor_parallel` is called.

    Common attributes (present in all subclasses):

@@ -194,6 +196,7 @@ class PretrainedConfig(PushToHubMixin):
    sub_configs: Dict[str, "PretrainedConfig"] = {}
    is_composition: bool = False
    attribute_map: Dict[str, str] = {}
    base_model_tp_plan: Optional[Dict[str, Any]] = None
    _auto_class: Optional[str] = None

    def __setattr__(self, key, value):
@@ -848,6 +851,9 @@ class PretrainedConfig(PushToHubMixin):

        if "_attn_implementation_internal" in serializable_config_dict:
            del serializable_config_dict["_attn_implementation_internal"]
        # Do not serialize `base_model_tp_plan` for now
        if "base_model_tp_plan" in serializable_config_dict:
            del serializable_config_dict["base_model_tp_plan"]

        return serializable_config_dict

@@ -867,6 +873,9 @@ class PretrainedConfig(PushToHubMixin):
            del output["_commit_hash"]
        if "_attn_implementation_internal" in output:
            del output["_attn_implementation_internal"]
        # Do not serialize `base_model_tp_plan` for now
        if "base_model_tp_plan" in output:
            del output["base_model_tp_plan"]

        # Transformers version when serializing the model
        output["transformers_version"] = __version__
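A quick way to see what the new config attribute holds, sketched against the Llama defaults added later in this commit (the output comments are assumptions based on that plan):

```python
# Sketch: `base_model_tp_plan` maps sub-module FQN patterns to neutral string styles,
# and is deliberately dropped when the config is serialized.
from transformers import LlamaConfig

config = LlamaConfig()

print(config.base_model_tp_plan["layers.*.self_attn.q_proj"])  # "colwise"
print(config.base_model_tp_plan["layers.*.mlp.down_proj"])     # "rowwise"

# `to_dict`/`to_diff_dict` exclude the plan, so it is not written to config.json for now.
assert "base_model_tp_plan" not in config.to_dict()
```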
@@ -55,6 +55,7 @@ from .pytorch_utils import ( # noqa: F401
    prune_conv1d_layer,
    prune_layer,
    prune_linear_layer,
    translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -1326,6 +1327,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
    # Has support for a `QuantoQuantizedCache` instance as `past_key_values`
    _supports_quantized_cache = False

    # A tensor parallel plan to be applied to the model when TP is enabled. For
    # top-level models, this attribute is currently defined in respective model
    # code. For base models, this attribute comes from
    # `config.base_model_tp_plan` during `post_init`.
    _tp_plan = None

    @property
    def dummy_inputs(self) -> Dict[str, torch.Tensor]:
        """
@@ -1370,6 +1377,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
        """
        self.init_weights()
        self._backward_compatibility_gradient_checkpointing()
        # If current model is a base model, attach `base_model_tp_plan` from config
        if self.base_model is self:
            self._tp_plan = self.config.base_model_tp_plan

    def dequantize(self):
        """
@@ -3399,6 +3409,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
        # Cache path to the GGUF file
        gguf_path = None

        tp_plan = kwargs.pop("tp_plan", None)
        if tp_plan is not None and tp_plan != "auto":
            # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
            raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")

        if is_fsdp_enabled():
            low_cpu_mem_usage = True

@@ -4000,6 +4015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):

        # Instantiate model.
        init_contexts = [no_init_weights(_enable=_fast_init)]
        tp_device = None

        if is_deepspeed_zero3_enabled() and not is_quantized:
            import deepspeed
@@ -4012,6 +4028,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
            )
            init_contexts.append(init_empty_weights())
        elif tp_plan is not None:
            if not torch.distributed.is_initialized():
                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
            device_type = torch._C._get_accelerator().type
            device_module = torch.get_device_module(device_type)
            # Get device with index assuming equal number of devices per host
            tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
            init_contexts.append(tp_device)

        if is_deepspeed_zero3_enabled() and is_quantized:
            init_contexts.append(set_quantized_state())
@@ -4145,32 +4171,38 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

        (
            model,
            missing_keys,
            unexpected_keys,
            mismatched_keys,
            offload_index,
            error_msgs,
        ) = cls._load_pretrained_model(
            model,
            state_dict,
            loaded_state_dict_keys,  # XXX: rename?
            resolved_archive_file,
            pretrained_model_name_or_path,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            sharded_metadata=sharded_metadata,
            _fast_init=_fast_init,
            low_cpu_mem_usage=low_cpu_mem_usage,
            device_map=device_map,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            dtype=torch_dtype,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_modules=keep_in_fp32_modules,
            gguf_path=gguf_path,
            weights_only=weights_only,
        )
        load_contexts = []
        # Make sure we load onto targeted device
        if tp_device is not None:
            load_contexts.append(tp_device)

        with ContextManagers(load_contexts):
            (
                model,
                missing_keys,
                unexpected_keys,
                mismatched_keys,
                offload_index,
                error_msgs,
            ) = cls._load_pretrained_model(
                model,
                state_dict,
                loaded_state_dict_keys,  # XXX: rename?
                resolved_archive_file,
                pretrained_model_name_or_path,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                sharded_metadata=sharded_metadata,
                _fast_init=_fast_init,
                low_cpu_mem_usage=low_cpu_mem_usage,
                device_map=device_map,
                offload_folder=offload_folder,
                offload_state_dict=offload_state_dict,
                dtype=torch_dtype,
                hf_quantizer=hf_quantizer,
                keep_in_fp32_modules=keep_in_fp32_modules,
                gguf_path=gguf_path,
                weights_only=weights_only,
            )

        # make sure token embedding weights are still tied if needed
        model.tie_weights()
@@ -4254,6 +4286,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
            }
            return model, loading_info

        if tp_plan is not None:
            assert tp_device is not None, "tp_device not set!"
            if not model.supports_tp_plan:
                raise NotImplementedError("This model does not have a tensor parallel plan.")
            # Assuming sharding the model onto the world
            world_size = torch.distributed.get_world_size()
            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
            # Apply Tensor Parallelism
            model.tensor_parallel(device_mesh)

        return model

    @classmethod
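For reference, the block above is roughly what a user would otherwise do by hand after loading; a minimal sketch under assumptions (model id, dtype, one-dimensional mesh over all ranks), to be launched with `torchrun`:

```python
# Rough manual equivalent of `from_pretrained(..., tp_plan="auto")`:
# load, build a 1-D device mesh over all ranks, then apply the model's tensor parallel plan.
import os

import torch
from transformers import AutoModelForCausalLM

rank = int(os.environ["RANK"])
device = torch.device("cuda", rank)
torch.distributed.init_process_group("nccl", device_id=device)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16  # assumed choices
).to(device)

if not model.supports_tp_plan:
    raise NotImplementedError("This model does not have a tensor parallel plan.")

device_mesh = torch.distributed.init_device_mesh("cuda", (torch.distributed.get_world_size(),))
model.tensor_parallel(device_mesh)
```

Unlike the `tp_plan="auto"` path above, this sketch materializes the full weights on every rank before sharding, so it only illustrates the call sequence, not the memory behavior.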
@@ -4943,6 +4985,54 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):

        return self.hf_quantizer.is_trainable

    @property
    def supports_tp_plan(self):
        """
        Returns whether the model has a tensor parallelism plan.
        """
        if self._tp_plan is not None:
            return True
        # Check if base model has a TP plan
        if getattr(self.base_model, "_tp_plan", None) is not None:
            return True
        return False

    def tensor_parallel(self, device_mesh):
        """
        Tensor parallelize the model across the given device mesh.

        Args:
            device_mesh (`torch.distributed.DeviceMesh`):
                The device mesh to use for tensor parallelism.
        """

        # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
        # No op if `_tp_plan` attribute does not exist under the module.
        # This is a helper function to be used with `model.apply` to recursively
        # parallelize a model.
        def tplize(mod: torch.nn.Module) -> None:
            tp_plan = getattr(mod, "_tp_plan", None)
            if tp_plan is None:
                return
            logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
            # In model configs, we use a neutral type (string) to specify
            # parallel styles, here we translate them into torch TP types.
            # Using tree_map because `tp_plan` is a dict.
            tp_plan = torch.utils._pytree.tree_map(
                translate_to_torch_parallel_style,
                tp_plan,
            )
            # Apply TP to current module.
            torch.distributed.tensor.parallel.parallelize_module(
                mod,
                device_mesh=device_mesh,
                parallelize_plan=tp_plan,
            )

        # `apply` is a native method of `nn.Module` that recursively applies a
        # function to every submodule.
        self.apply(tplize)

    @property
    def loss_function(self):
        if getattr(self.config, "loss_type", None) is not None:
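A small, CPU-only sanity check of how the plan ends up on the modules after `post_init` (a sketch; the tiny config values are arbitrary and only keep the model small):

```python
# Sketch: the base model picks up `base_model_tp_plan` from its config during `post_init`,
# while the causal-LM wrapper also carries a class-level `_tp_plan` for `lm_head`.
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=4, vocab_size=256,
)
model = LlamaForCausalLM(config)

print(model._tp_plan)          # {"lm_head": "colwise_rep"}
print(model.model._tp_plan)    # the `base_model_tp_plan` dict from LlamaConfig
print(model.supports_tp_plan)  # True
```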
@@ -1068,7 +1068,7 @@ class CohereModel(CoherePreTrainedModel):
        return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

@@ -720,7 +720,10 @@ class GemmaModel(GemmaPreTrainedModel):
            [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        if getattr(config, "pretraining_tp", 1) != 1:
            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

        # Initialize weights and apply final processing
        self.post_init()
@@ -982,6 +985,7 @@ class GemmaModel(GemmaPreTrainedModel):

class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)

@@ -740,7 +740,10 @@ class Gemma2Model(Gemma2PreTrainedModel):
            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        if getattr(config, "pretraining_tp", 1) != 1:
            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

        # Initialize weights and apply final processing
        self.post_init()
@@ -961,6 +964,7 @@ class Gemma2Model(Gemma2PreTrainedModel):

class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)

@@ -708,6 +708,8 @@ class GlmModel(GlmPreTrainedModel):
            dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
        )
        self.gradient_checkpointing = False
        if getattr(config, "pretraining_tp", 1) != 1:
            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

        # Initialize weights and apply final processing
        self.post_init()
@@ -967,6 +969,7 @@ class GlmModel(GlmPreTrainedModel):

class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config: GlmConfig):
        super().__init__(config)

@@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig):

    model_type = "llama"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `LlamaModel`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,

@@ -21,7 +21,6 @@ import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

@@ -240,25 +239,7 @@ class LlamaMLP(nn.Module):
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


@@ -320,31 +301,14 @@ class LlamaAttention(nn.Module):
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
@@ -386,12 +350,7 @@ class LlamaAttention(nn.Module):

        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
@@ -564,9 +523,10 @@ class LlamaSdpaAttention(LlamaAttention):
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
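The reason `-1` is needed in the new `view` calls: under colwise sharding each rank's `q_proj`/`k_proj`/`v_proj` only produces its local share of the heads, so the head count can no longer be taken from the config. A toy shape check (the sizes are made up):

```python
# Toy shape check for the `view(bsz, q_len, -1, head_dim)` change. With 32 heads of size 128
# split colwise over 4 ranks, each rank's projection outputs 8 local heads; `-1` infers that
# local count, whereas the old `view(..., num_heads, head_dim)` would expect all 32.
import torch

bsz, q_len, head_dim = 2, 16, 128
num_heads, tp_degree = 32, 4
local_out_features = num_heads * head_dim // tp_degree  # what one rank's q_proj produces

local_query = torch.randn(bsz, q_len, local_out_features)
query_states = local_query.view(bsz, q_len, -1, head_dim).transpose(1, 2)
print(query_states.shape)  # torch.Size([2, 8, 16, 128]) -> 8 local heads on this rank
```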
@@ -850,7 +810,10 @@ class LlamaModel(LlamaPreTrainedModel):
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        if getattr(config, "pretraining_tp", 1) != 1:
            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

        # Initialize weights and apply final processing
        self.post_init()
@@ -1113,6 +1076,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
@@ -1211,13 +1175,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
@@ -980,7 +980,7 @@ class NemotronModel(NemotronPreTrainedModel):
        return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

@@ -1020,7 +1020,7 @@ class OlmoModel(OlmoPreTrainedModel):
        return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

@@ -971,6 +971,7 @@ class Olmo1124Model(Olmo1124PreTrainedModel):
        return causal_mask


# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO_1124,Llama->Olmo1124
class Olmo1124ForCausalLM(Olmo1124PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

@@ -888,7 +888,7 @@ OLMOE_INPUTS_DOCSTRING = r"""
    "The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
    OLMOE_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
class OlmoeModel(OlmoePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`]

@@ -775,7 +775,7 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
        return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

@@ -20,6 +20,11 @@ import torch
from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
)

from .utils import is_torch_xla_available, logging

@@ -329,3 +334,22 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
    else:
        # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
        return torch.isin(elements, test_elements)


def translate_to_torch_parallel_style(style: str):
    """
    In model configurations, we use a neutral type (string) to specify parallel
    styles, here we translate them into torch.distributed tensor-parallel
    types.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    if style == "colwise":
        return ColwiseParallel()
    elif style == "rowwise":
        return RowwiseParallel()
    elif style == "colwise_rep":
        return ColwiseParallel(output_layouts=Replicate())
    else:
        raise ValueError(f"Unsupported parallel style value: {style}")
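A quick look at what this helper returns for each supported string (a sketch; it assumes a recent PyTorch build where the torch.distributed tensor parallel styles are importable):

```python
# Sketch: the neutral string styles used in `_tp_plan` / `base_model_tp_plan`
# map onto torch.distributed tensor parallel style objects via this helper.
from transformers.pytorch_utils import translate_to_torch_parallel_style

for style in ("colwise", "rowwise", "colwise_rep"):
    print(style, "->", type(translate_to_torch_parallel_style(style)).__name__)
# colwise -> ColwiseParallel
# rowwise -> RowwiseParallel
# colwise_rep -> ColwiseParallel (with output_layouts=Replicate(), i.e. the sharded output is gathered)
```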
tests/tp/test_tp.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from transformers import is_torch_available
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    get_torch_dist_unique_port,
    require_torch_multi_gpu,
)


if is_torch_available():
    import torch


class TestTensorParallel(TestCasePlus):
    @require_torch_multi_gpu
    def test_tp(self):
        distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
            --master_port={get_torch_dist_unique_port()}
            {self.test_file_dir}/test_tp.py
        """.split()
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"--output_dir {output_dir} --report_to none".split()
        cmd = ["torchrun"] + distributed_args + args
        print(cmd)
        execute_subprocess_async(cmd, env=self.get_env())
        # successful return here == success - any errors would have caused an error in the sub-call


if __name__ == "__main__":
    # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
    # CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tp/test_tp.py
    # or
    # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py

    if not is_torch_available():
        exit(0)

    # Test settings
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    bs = 4
    seqlen = 64

    # Get distributed settings
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Initialize distributed
    device = torch.device(f"cuda:{rank}")
    torch.distributed.init_process_group("nccl", device_id=device)
    device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))

    # Get model config
    config = LlamaConfig.from_pretrained(model_id)
    # Shrink model size
    config.num_hidden_layers //= 8
    config.vocab_size //= 8

    # Instantiate model
    with device:
        model = LlamaModel(config)

    model.eval()

    # Tensor Parallel
    if world_size > 1:
        model.tensor_parallel(device_mesh)

    # Run model
    inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
    with torch.no_grad():
        out = model(inputs)

    assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])