mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

[Llama + Mistral] Add attention dropout (#27315)

* add dropouts
* add the dropout
* add doc in the config
* nits
* fix mistral config
* nits
parent b97cab7e6d
commit 210e38d83f

src/transformers/models/llama/configuration_llama.py

@@ -95,7 +95,8 @@ class LlamaConfig(PretrainedConfig):
             experimental feature, subject to breaking API changes in future versions.
         attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
-
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
 
     ```python
     >>> from transformers import LlamaModel, LlamaConfig
@@ -133,6 +134,7 @@ class LlamaConfig(PretrainedConfig):
         rope_theta=10000.0,
         rope_scaling=None,
         attention_bias=False,
+        attention_dropout=0.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -156,6 +158,7 @@ class LlamaConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self._rope_scaling_validation()
         self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
 
         super().__init__(
             pad_token_id=pad_token_id,
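
The new field is an ordinary config argument, so it can be set at construction time or in `config.json`. A minimal sketch of how one might enable it (assumes a transformers version that includes this commit; the model sizes are illustrative, not a released checkpoint):

```python
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny illustrative configuration -- attention_dropout is the only point of interest here.
config = LlamaConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=8,
    attention_dropout=0.1,  # dropout ratio applied to the attention probabilities
)
model = LlamaForCausalLM(config)
print(model.config.attention_dropout)  # 0.1
```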

src/transformers/models/llama/modeling_llama.py

@@ -278,6 +278,7 @@ class LlamaAttention(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.config = config
+        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -292,6 +293,7 @@ class LlamaAttention(nn.Module):
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
 
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
@@ -404,6 +406,7 @@ class LlamaAttention(nn.Module):
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
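
For reference, a self-contained sketch (not the repository code; causal masking and KV caching omitted) of what the eager path above now computes: dropout is applied to the softmax-normalized attention probabilities before they are multiplied with the values, and `training=self.training` makes it a no-op at inference time.

```python
import torch
import torch.nn as nn


def eager_attention_with_dropout(query, key, value, attention_dropout, training):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    head_dim = query.size(-1)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / head_dim**0.5
    # upcast to fp32 for a numerically stable softmax, mirroring the diff above
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # the added line: dropout on the attention probabilities, active only while training
    attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
    return torch.matmul(attn_weights, value)


q = k = v = torch.randn(1, 8, 16, 32)
out = eager_attention_with_dropout(q, k, v, attention_dropout=0.1, training=True)
print(out.shape)  # torch.Size([1, 8, 16, 32])
```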
@@ -489,10 +492,7 @@ class LlamaFlashAttention2(LlamaAttention):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        # TODO: llama does not have dropout in the config??
-        # It is recommended to use dropout with FA according to the docs
-        # when training.
-        dropout_rate = 0.0  # if not self.training else self.attn_dropout
+        dropout_rate = 0.0 if not self.training else self.attention_dropout
 
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
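
Flash Attention fuses the softmax and the value multiplication, so dropout cannot be applied to the probabilities after the fact; a dropout probability is handed to the kernel instead, and the `dropout_rate = 0.0 if not self.training else self.attention_dropout` gate keeps inference deterministic. A rough equivalent of that gating using PyTorch's built-in fused attention (an illustration under that assumption, not the code path in this file):

```python
import torch
import torch.nn.functional as F


def fused_causal_attention(query, key, value, attention_dropout, training):
    # Same gate as the diff: only pass a non-zero dropout probability while training.
    dropout_rate = 0.0 if not training else attention_dropout
    return F.scaled_dot_product_attention(
        query, key, value, is_causal=True, dropout_p=dropout_rate
    )


q = k = v = torch.randn(1, 8, 16, 32)
out = fused_causal_attention(q, k, v, attention_dropout=0.1, training=True)
```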

src/transformers/models/mistral/configuration_mistral.py

@@ -82,7 +82,8 @@ class MistralConfig(PretrainedConfig):
             The base period of the RoPE embeddings.
         sliding_window (`int`, *optional*, defaults to 4096):
             Sliding window attention window size. If not specified, will default to `4096`.
-
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
 
     ```python
     >>> from transformers import MistralModel, MistralConfig
@@ -119,6 +120,7 @@ class MistralConfig(PretrainedConfig):
         tie_word_embeddings=False,
         rope_theta=10000.0,
         sliding_window=4096,
+        attention_dropout=0.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -139,6 +141,7 @@ class MistralConfig(PretrainedConfig):
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
         self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
 
         super().__init__(
             pad_token_id=pad_token_id,
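
The Mistral counterpart behaves the same way; a short sketch (illustrative sizes, assumes a post-merge install) combining the new field with Mistral's existing grouped-query attention and sliding-window options:

```python
from transformers import MistralConfig, MistralForCausalLM

config = MistralConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=2,   # grouped-query attention
    sliding_window=1024,
    attention_dropout=0.1,   # dropout ratio applied to the attention probabilities
)
model = MistralForCausalLM(config)
```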

src/transformers/models/mistral/modeling_mistral.py

@@ -205,6 +205,7 @@ class MistralAttention(nn.Module):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
+        self.attention_dropout = config.attention_dropout
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -284,6 +285,7 @@ class MistralAttention(nn.Module):
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -390,11 +392,7 @@ class MistralFlashAttention2(MistralAttention):
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        # TODO: Mistral does not have dropout in the config??
-        # It is recommended to use dropout with FA according to the docs
-        # when training.
-        dropout_rate = 0.0  # if not self.training else self.attn_dropout
+        dropout_rate = 0.0 if not self.training else self.attention_dropout
 
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
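
Because both the eager path (`nn.functional.dropout(..., training=self.training)`) and the Flash Attention path (the `dropout_rate` gate) key off `self.training`, the new dropout only perturbs forward passes in train mode. A small check one could run against a toy model (illustrative sizes, assumes a post-merge install):

```python
import torch
from transformers import MistralConfig, MistralForCausalLM

config = MistralConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
    attention_dropout=0.5,
)
model = MistralForCausalLM(config)
input_ids = torch.randint(0, config.vocab_size, (1, 8))

model.eval()  # dropout disabled -> repeated forward passes match
with torch.no_grad():
    assert torch.allclose(model(input_ids).logits, model(input_ids).logits)

model.train()  # dropout active -> forward passes become stochastic
with torch.no_grad():
    out1 = model(input_ids).logits
    out2 = model(input_ids).logits
print(torch.equal(out1, out2))  # almost certainly False with p=0.5
```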