Add PyTorch Tensor Parallel support for Mistral (#34927)

add base tp support
Vladislav Bronzov 2024-11-26 14:28:07 +01:00 committed by GitHub
parent 4d1d0f29a4
commit 1141eff1bd
2 changed files with 14 additions and 3 deletions

src/transformers/models/mistral/configuration_mistral.py

@@ -97,6 +97,16 @@ class MistralConfig(PretrainedConfig):
 
     model_type = "mistral"
     keys_to_ignore_at_inference = ["past_key_values"]
+    # Default tensor parallel plan for base model `MistralModel`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
 
     def __init__(
         self,
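
In this plan, "colwise" shards a Linear layer's output features across ranks and "rowwise" shards its input features (reducing the partial outputs), so each attention/MLP pair needs only one communication step. As a rough illustration of what the plan expands to, here is a minimal sketch using PyTorch's tensor-parallel API directly; this is not the transformers loading code, and it assumes a process group launched via torchrun plus a MistralModel instance bound to model:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# One 1-D mesh over all ranks; the "layers.*" wildcards in the plan above
# correspond to applying the same per-layer dict to every decoder layer.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
for layer in model.layers:
    parallelize_module(
        layer,
        mesh,
        {
            "self_attn.q_proj": ColwiseParallel(),
            "self_attn.k_proj": ColwiseParallel(),
            "self_attn.v_proj": ColwiseParallel(),
            "self_attn.o_proj": RowwiseParallel(),
            "mlp.gate_proj": ColwiseParallel(),
            "mlp.up_proj": ColwiseParallel(),
            "mlp.down_proj": RowwiseParallel(),
        },
    )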

src/transformers/models/mistral/modeling_mistral.py

@@ -227,9 +227,9 @@ class MistralAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
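
Replacing self.num_heads / self.num_key_value_heads with -1 is what makes the reshape work under tensor parallelism: with the colwise plan above, each rank's q_proj/k_proj/v_proj holds only its shard of the heads, so a hardcoded head count no longer matches the local tensor. A standalone illustration with made-up numbers:

import torch

bsz, q_len, num_heads, head_dim, tp_degree = 2, 16, 32, 128, 4
# Each rank sees only its shard of the projection output: 8 of the 32 heads.
local_q = torch.randn(bsz, q_len, num_heads * head_dim // tp_degree)
# view(bsz, q_len, 32, 128) would raise a RuntimeError here (it needs 4096
# features per token, but the local shard has 1024); -1 infers 8 local heads.
local_q = local_q.view(bsz, q_len, -1, head_dim).transpose(1, 2)
assert local_q.shape == (bsz, 8, q_len, head_dim)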
@@ -983,6 +983,7 @@ class MistralModel(MistralPreTrainedModel):
 
 class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
 
     def __init__(self, config):
         super().__init__(config)
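
The lm_head entry uses "colwise_rep": the projection is sharded column-wise and its output replicated (gathered), so every rank ends up with full-vocabulary logits for generation. With both plan attributes in place, tensor parallelism is requested at load time. A minimal usage sketch, assuming the tp_plan="auto" loading path in transformers and an illustrative checkpoint id (launch with torchrun --nproc-per-node 4 run_tp.py):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# tp_plan="auto" applies base_model_tp_plan / _tp_plan across the ranks.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, tp_plan="auto"
)

prompt = "Tensor parallelism shards each layer"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))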