[gpt-neox] Add attention_bias config to support model trained without attention biases (#28126)
* add attention_bias hparam for a model trained without attention biases
* fix argument documentation error
This commit is contained in:
parent def581ef51
commit cd9f9d63f1
src/transformers/models/gpt_neox/configuration_gpt_neox.py

@@ -86,6 +86,8 @@ class GPTNeoXConfig(PretrainedConfig):
             these scaling strategies behave:
             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
             experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.

     Example:

@@ -126,6 +128,7 @@ class GPTNeoXConfig(PretrainedConfig):
         tie_word_embeddings=False,
         use_parallel_residual=True,
         rope_scaling=None,
+        attention_bias=True,
         **kwargs,
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -147,6 +150,7 @@ class GPTNeoXConfig(PretrainedConfig):
         self.tie_word_embeddings = tie_word_embeddings
         self.use_parallel_residual = use_parallel_residual
         self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
         self._rope_scaling_validation()

         if self.hidden_size % self.num_attention_heads != 0:
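As a quick illustration of the config-side change (not part of the commit itself), the sketch below constructs a GPTNeoXConfig with the new flag and checks that the documented default of `True` leaves existing configurations unchanged. It assumes a transformers version that already contains this commit.

from transformers import GPTNeoXConfig

# Default behaviour is unchanged: attention projections keep their biases.
default_config = GPTNeoXConfig()
assert default_config.attention_bias is True

# For a checkpoint trained without attention biases, turn the flag off.
no_bias_config = GPTNeoXConfig(attention_bias=False)
assert no_bias_config.attention_bias is False

# The flag round-trips through serialization like any other config field.
assert GPTNeoXConfig.from_dict(no_bias_config.to_dict()).attention_bias is False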
src/transformers/models/gpt_neox/modeling_gpt_neox.py

@@ -117,8 +117,8 @@ class GPTNeoXAttention(nn.Module):
         self._init_rope()

         self.norm_factor = self.head_size**-0.5
-        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
         self.is_causal = True

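For the modeling-side change, here is a minimal sketch (again, not from the commit) of what passing `bias=config.attention_bias` does: with the flag off, `nn.Linear` registers no bias parameter at all, which matches checkpoints trained without attention biases. The layer names mirror those in GPTNeoXAttention; the small hidden_size and num_attention_heads values are arbitrary.

import torch.nn as nn
from transformers import GPTNeoXConfig

config = GPTNeoXConfig(hidden_size=64, num_attention_heads=4, attention_bias=False)

# Same construction pattern as GPTNeoXAttention after this commit.
query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)

# With attention_bias=False, no bias parameters exist on the projections.
assert query_key_value.bias is None and dense.bias is None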