Allow-head-dim (#32857)

* support head dim

* fix the doc

* fixup

* add oproj

Co-authored-by: Suhara
<suhara@users.noreply.github.com>>

* update

Co-authored-by: bzantium <bzantium@users.noreply.github.com>

* Co-authored-by: suhara <suhara@users.noreply.github.com>

* Update

Co-authored-by: Yoshi Suhara <suhara@users.noreply.github.com>

---------

Co-authored-by: bzantium <bzantium@users.noreply.github.com>
Co-authored-by: Yoshi Suhara <suhara@users.noreply.github.com>
This commit is contained in:
Arthur 2024-08-20 10:24:48 +02:00 committed by GitHub
parent 85345bb439
commit 13e645bb40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 16 deletions

View File

@ -123,6 +123,8 @@ class LlamaConfig(PretrainedConfig):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_heads
```python
>>> from transformers import LlamaModel, LlamaConfig
@ -163,6 +165,7 @@ class LlamaConfig(PretrainedConfig):
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
**kwargs,
):
self.vocab_size = vocab_size
@ -187,7 +190,7 @@ class LlamaConfig(PretrainedConfig):
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:

View File

@ -214,12 +214,6 @@ class FlaxLlamaAttention(nn.Module):
self.k_proj = dense(self.num_key_value_heads * self.head_dim)
self.v_proj = dense(self.num_key_value_heads * self.head_dim)
self.o_proj = dense(self.embed_dim)
if (self.head_dim * self.num_heads) != self.embed_dim:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=self.dtype)

View File

@ -340,23 +340,17 @@ class LlamaAttention(nn.Module):
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
@ -420,7 +414,6 @@ class LlamaAttention(nn.Module):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it