[Llama + Mistral] Add attention dropout (#27315)

* add droppouts

* add the dropout

* add doc in the config

* nits

* fix mistral config

* nits
This commit is contained in:
Arthur 2023-11-13 14:51:48 +01:00 committed by GitHub
parent b97cab7e6d
commit 210e38d83f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 11 deletions

View File

@ -95,7 +95,8 @@ class LlamaConfig(PretrainedConfig):
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import LlamaModel, LlamaConfig
@ -133,6 +134,7 @@ class LlamaConfig(PretrainedConfig):
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
@ -156,6 +158,7 @@ class LlamaConfig(PretrainedConfig):
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,

View File

@ -278,6 +278,7 @@ class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@ -292,6 +293,7 @@ class LlamaAttention(nn.Module):
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
@ -404,6 +406,7 @@ class LlamaAttention(nn.Module):
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@ -489,10 +492,7 @@ class LlamaFlashAttention2(LlamaAttention):
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# TODO: llama does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout
dropout_rate = 0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need

View File

@ -82,7 +82,8 @@ class MistralConfig(PretrainedConfig):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention window size. If not specified, will default to `4096`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import MistralModel, MistralConfig
@ -119,6 +120,7 @@ class MistralConfig(PretrainedConfig):
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
@ -139,6 +141,7 @@ class MistralConfig(PretrainedConfig):
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,

View File

@ -205,6 +205,7 @@ class MistralAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_dropout = config.attention_dropout
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
@ -284,6 +285,7 @@ class MistralAttention(nn.Module):
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@ -390,11 +392,7 @@ class MistralFlashAttention2(MistralAttention):
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# TODO: Mistral does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout
dropout_rate = 0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need