Update configuration_qwen2.py (#36735)

* Update configuration_qwen2_moe.py

* Update modeling_qwen2_moe.py

* ruff fmt

* docstring add qkv_bias
Author: Michael Feil, 2025-03-19 11:15:54 -07:00 (committed by GitHub)
Commit: 51bd0ceb9e (parent: 107fedc1e2)
3 changed files with 10 additions and 4 deletions

File: configuration_qwen2_moe.py

@@ -133,7 +133,8 @@ class Qwen2MoeConfig(PretrainedConfig):
Indicate which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+qkv_bias (`bool`, *optional*, defaults to `True`):
+    Whether to add a bias to the queries, keys and values.
```python
>>> from transformers import Qwen2MoeModel, Qwen2MoeConfig
@@ -195,6 +196,7 @@ class Qwen2MoeConfig(PretrainedConfig):
output_router_logits=False,
router_aux_loss_coef=0.001,
mlp_only_layers=None,
+qkv_bias=True,
**kwargs,
):
self.vocab_size = vocab_size
@@ -231,6 +233,7 @@ class Qwen2MoeConfig(PretrainedConfig):
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
+self.qkv_bias = qkv_bias
super().__init__(
tie_word_embeddings=tie_word_embeddings,

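For reference, a minimal sketch of setting the new flag on the config, assuming a transformers version that includes this change:

```python
from transformers import Qwen2MoeConfig

# qkv_bias defaults to True, preserving the original Qwen2-MoE behavior
# (bias terms on the query/key/value projections).
config = Qwen2MoeConfig(qkv_bias=False)
print(config.qkv_bias)  # False
```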
File: modeling_qwen2_moe.py

@@ -327,9 +327,9 @@ class Qwen2MoeAttention(nn.Module):
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
-self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
-self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
-self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.config.qkv_bias)
+self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.qkv_bias)
+self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.qkv_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config)

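A hedged sanity check that the attention projections pick up the flag. The tiny hyperparameters below are illustrative placeholders, not the model defaults, chosen only so the random-weight model builds quickly:

```python
from transformers import Qwen2MoeConfig, Qwen2MoeModel

# Tiny illustrative configuration (placeholder sizes, not Qwen2-MoE defaults).
config = Qwen2MoeConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    moe_intermediate_size=16,
    shared_expert_intermediate_size=16,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_experts=4,
    num_experts_per_tok=2,
    qkv_bias=False,
)
model = Qwen2MoeModel(config)
attn = model.layers[0].self_attn
# With qkv_bias=False, the q/k/v projections are created without bias terms.
print(attn.q_proj.bias is None, attn.k_proj.bias is None)  # True True
```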
File: Qwen2Moe model tests (Qwen2MoeModelTester)

@@ -89,6 +89,7 @@ class Qwen2MoeModelTester:
pad_token_id=0,
bos_token_id=1,
scope=None,
+qkv_bias=False,
):
self.parent = parent
self.batch_size = batch_size
@@ -127,6 +128,7 @@ class Qwen2MoeModelTester:
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
+self.qkv_bias = qkv_bias
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
@@ -181,6 +183,7 @@ class Qwen2MoeModelTester:
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
+qkv_bias=self.qkv_bias,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen2Moe
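The tester simply threads the flag into the config it builds. A minimal sketch of the round-trip behavior the tests rely on (the attribute serializes like any other Qwen2MoeConfig field; the snippet is an illustration, not part of the test suite):

```python
from transformers import Qwen2MoeConfig

# The tester passes qkv_bias=False; the flag round-trips through
# config serialization like any other Qwen2MoeConfig attribute.
config = Qwen2MoeConfig(qkv_bias=False)
restored = Qwen2MoeConfig.from_dict(config.to_dict())
assert restored.qkv_bias is False
```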