mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Update configuration_qwen2.py (#36735)
* Update configuration_qwen2_moe.py * Update modeling_qwen2_moe.py * ruff fmt * docstring add qkv_bias
This commit is contained in:
parent
107fedc1e2
commit
51bd0ceb9e
@ -133,7 +133,8 @@ class Qwen2MoeConfig(PretrainedConfig):
|
||||
Indicate which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock
|
||||
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
|
||||
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
|
||||
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
```python
|
||||
>>> from transformers import Qwen2MoeModel, Qwen2MoeConfig
|
||||
|
||||
@ -195,6 +196,7 @@ class Qwen2MoeConfig(PretrainedConfig):
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.001,
|
||||
mlp_only_layers=None,
|
||||
qkv_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
@ -231,6 +233,7 @@ class Qwen2MoeConfig(PretrainedConfig):
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
|
||||
self.qkv_bias = qkv_bias
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
|
@ -327,9 +327,9 @@ class Qwen2MoeAttention(nn.Module):
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.config.qkv_bias)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.qkv_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.qkv_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config)
|
||||
|
@ -89,6 +89,7 @@ class Qwen2MoeModelTester:
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
scope=None,
|
||||
qkv_bias=False,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@ -127,6 +128,7 @@ class Qwen2MoeModelTester:
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.qkv_bias = qkv_bias
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
|
||||
def prepare_config_and_inputs(self):
|
||||
@ -181,6 +183,7 @@ class Qwen2MoeModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
pad_token_id=self.pad_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
qkv_bias=self.qkv_bias,
|
||||
)
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen2Moe
|
||||
|
Loading…
Reference in New Issue
Block a user