Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Allow-head-dim (#32857)

* support head dim

* fix the doc

* fixup

* add oproj

Co-authored-by: Suhara <suhara@users.noreply.github.com>

* update

Co-authored-by: bzantium <bzantium@users.noreply.github.com>
Co-authored-by: suhara <suhara@users.noreply.github.com>

* Update

Co-authored-by: Yoshi Suhara <suhara@users.noreply.github.com>

---------

Co-authored-by: bzantium <bzantium@users.noreply.github.com>
Co-authored-by: Yoshi Suhara <suhara@users.noreply.github.com>
commit 13e645bb40 (parent 85345bb439)
@@ -123,6 +123,8 @@ class LlamaConfig(PretrainedConfig):
             The dropout ratio for the attention probabilities.
         mlp_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+        head_dim (`int`, *optional*):
+            The attention head dimension. If None, it will default to hidden_size // num_heads
 
     ```python
     >>> from transformers import LlamaModel, LlamaConfig
@@ -163,6 +165,7 @@ class LlamaConfig(PretrainedConfig):
         attention_bias=False,
         attention_dropout=0.0,
         mlp_bias=False,
+        head_dim=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -187,7 +190,7 @@ class LlamaConfig(PretrainedConfig):
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.mlp_bias = mlp_bias
-
+        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
         # Validate the correctness of rotary position embeddings parameters
         # BC: if there is a 'type' field, move it to 'rope_type'.
         if self.rope_scaling is not None and "type" in self.rope_scaling:
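For reference, a minimal sketch (assuming a `transformers` build that includes this change) of how the new `head_dim` argument behaves in `LlamaConfig`:

```python
from transformers import LlamaConfig

# No head_dim given: falls back to hidden_size // num_attention_heads (4096 // 32 = 128).
default_config = LlamaConfig(hidden_size=4096, num_attention_heads=32)
print(default_config.head_dim)  # 128

# Explicit head_dim: now decoupled from hidden_size // num_attention_heads.
custom_config = LlamaConfig(hidden_size=4096, num_attention_heads=32, head_dim=64)
print(custom_config.head_dim)  # 64
```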
@@ -214,12 +214,6 @@ class FlaxLlamaAttention(nn.Module):
         self.k_proj = dense(self.num_key_value_heads * self.head_dim)
         self.v_proj = dense(self.num_key_value_heads * self.head_dim)
         self.o_proj = dense(self.embed_dim)
-        if (self.head_dim * self.num_heads) != self.embed_dim:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.embed_dim}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-
         self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
         self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=self.dtype)
 
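The divisibility check is dropped because `head_dim * num_heads` no longer has to equal `embed_dim`. A standalone Flax sketch (hypothetical dimensions, not the model code itself) of the projection shapes this permits:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

embed_dim, num_heads, head_dim = 4096, 32, 64  # 32 * 64 = 2048 != 4096 is now allowed

k_proj = nn.Dense(num_heads * head_dim, use_bias=False)  # embed_dim -> num_heads * head_dim
o_proj = nn.Dense(embed_dim, use_bias=False)             # attention width -> embed_dim

x = jnp.ones((1, 8, embed_dim))                          # (batch, seq_len, embed_dim)
k_params = k_proj.init(jax.random.PRNGKey(0), x)
print(k_proj.apply(k_params, x).shape)                   # (1, 8, 2048)
```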
@@ -340,23 +340,17 @@ class LlamaAttention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
 
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
         # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
         self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
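A small plain-PyTorch sketch (hypothetical sizes, not the actual `LlamaAttention` code) of why `o_proj` must take `num_heads * head_dim` rather than `hidden_size` as its input dimension once the two are decoupled:

```python
import torch
from torch import nn

hidden_size, num_heads, head_dim = 4096, 32, 64  # num_heads * head_dim = 2048 != hidden_size

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)  # input follows the attention width

x = torch.randn(2, 8, hidden_size)   # (batch, seq_len, hidden_size)
q = q_proj(x)                        # (2, 8, 2048)
out = o_proj(q)                      # back to (2, 8, 4096)
print(q.shape, out.shape)
```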
@@ -420,7 +414,6 @@ class LlamaAttention(nn.Module):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attention_mask is not None:  # no matter the length, we just slice it
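For completeness, a toy reproduction (hypothetical shapes) of the eager score computation shown above, where the scaling uses the configured `head_dim` rather than `hidden_size // num_heads`:

```python
import math
import torch

bsz, num_heads, q_len, kv_len, head_dim = 2, 32, 8, 8, 64

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_len, head_dim)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
print(attn_weights.shape)  # torch.Size([2, 32, 8, 8])
```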