Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[mistral] Support passing head_dim through config (and do not require head_dim * num_heads == hidden_size) (#32050)
* Allow `head_dim` to be set in Mistral config
* Add docstring
* Do not require `head_dim * num_heads == hidden_size`
* [run-slow] mistral
This commit is contained in:
parent c50e0551fd
commit 4c040aba02
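In practice, the change lets a Mistral config declare a head dimension that is decoupled from the hidden size. A minimal sketch of the new usage, assuming a version of `transformers` that includes this commit (`MistralConfig` and `MistralForCausalLM` are the standard library classes; the concrete sizes are illustrative only):

```python
from transformers import MistralConfig, MistralForCausalLM

# head_dim can now be passed explicitly; before this commit it was always
# derived as hidden_size // num_attention_heads, and the attention layer
# rejected any config where head_dim * num_heads != hidden_size.
config = MistralConfig(
    hidden_size=2048,
    num_attention_heads=16,
    num_key_value_heads=4,
    head_dim=256,  # 16 * 256 = 4096 != 2048 -- now allowed
)
model = MistralForCausalLM(config)
```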
@@ -53,6 +53,8 @@ class MistralConfig(PretrainedConfig):
             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
             by meanpooling all the original heads within that group. For more details checkout [this
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+            The attention head dimension.
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
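The docstring's mean-pooling recipe for building a GQA checkpoint out of a multi-head one can be made concrete. A sketch, assuming key/value weights laid out as in the `k_proj`/`v_proj` layers in the attention hunk below (`meanpool_kv_heads` is a hypothetical helper for illustration, not part of the library):

```python
import torch

def meanpool_kv_heads(kv_weight: torch.Tensor, num_heads: int,
                      num_key_value_heads: int, head_dim: int) -> torch.Tensor:
    # kv_weight: (num_heads * head_dim, hidden_size), e.g. k_proj.weight.
    # Collapse each group of num_heads // num_key_value_heads original heads
    # into one GQA head by averaging them, as the docstring describes.
    hidden_size = kv_weight.shape[1]
    group_size = num_heads // num_key_value_heads
    grouped = kv_weight.view(num_key_value_heads, group_size, head_dim, hidden_size)
    return grouped.mean(dim=1).reshape(num_key_value_heads * head_dim, hidden_size)
```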
@@ -104,6 +106,7 @@ class MistralConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
+        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -125,6 +128,7 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
+        self.head_dim = head_dim or hidden_size // num_attention_heads

         # for backward compatibility
         if num_key_value_heads is None:
@@ -185,22 +185,17 @@ class MistralAttention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = config.head_dim
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True

-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

         self.rotary_emb = MistralRotaryEmbedding(
             self.head_dim,
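The net effect of the attention changes: the divisibility check is gone, and `o_proj` now takes `num_heads * head_dim` inputs rather than `hidden_size`, so the projections round-trip correctly even when the two differ. A standalone shape sketch mirroring the diff (plain `nn.Linear` layers, not the real `MistralAttention`; sizes are illustrative):

```python
import torch
from torch import nn

hidden_size, num_heads, num_kv_heads, head_dim = 2048, 16, 4, 256  # 16*256 != 2048

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
# o_proj maps the concatenated head outputs back to the model width;
# its input is now num_heads * head_dim rather than hidden_size.
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

x = torch.randn(1, 10, hidden_size)
assert q_proj(x).shape == (1, 10, num_heads * head_dim)   # (1, 10, 4096)
assert o_proj(q_proj(x)).shape == (1, 10, hidden_size)    # (1, 10, 2048)
```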