manual head_dim for mixtral model (#34281)

This commit is contained in:
Doohae Jung 2024-10-29 22:31:36 +09:00 committed by GitHub
parent 5392f12e16
commit 8755dd26b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 9 deletions

View File

@ -53,6 +53,8 @@ class MixtralConfig(PretrainedConfig):
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
@ -116,6 +118,7 @@ class MixtralConfig(PretrainedConfig):
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
head_dim=None,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
@ -154,6 +157,7 @@ class MixtralConfig(PretrainedConfig):
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts

View File

@ -283,7 +283,7 @@ class MixtralAttention(nn.Module):
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
@ -291,11 +291,6 @@ class MixtralAttention(nn.Module):
self.is_causal = True
self.attention_dropout = config.attention_dropout
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@ -374,7 +369,7 @@ class MixtralAttention(nn.Module):
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
@ -481,7 +476,7 @@ class MixtralFlashAttention2(MixtralAttention):
is_causal=self.is_causal,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
@ -575,7 +570,7 @@ class MixtralSdpaAttention(MixtralAttention):
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)