mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Adding Qwen3 and Qwen3MoE (#36878)
* Initial commit for Qwen3 * fix and add tests for qwen3 & qwen3_moe * rename models for tests. * fix * fix * fix and add docs. * fix model name in docs. * simplify modular and fix configuration issues * Fix the red CI: ruff was updated * revert ruff, version was wrong * fix qwen3moe. * fix * make sure MOE can load * fix copies --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
parent
0d6a60fe55
commit
6acd5aecb3
@ -603,6 +603,10 @@
|
||||
title: Qwen2
|
||||
- local: model_doc/qwen2_moe
|
||||
title: Qwen2MoE
|
||||
- local: model_doc/qwen3
|
||||
title: Qwen3
|
||||
- local: model_doc/qwen3_moe
|
||||
title: Qwen3MoE
|
||||
- local: model_doc/rag
|
||||
title: RAG
|
||||
- local: model_doc/realm
|
||||
|
@ -43,4 +43,3 @@ Transformers is designed for developers and machine learning engineers and resea
|
||||
</a>
|
||||
</div>
|
||||
|
||||
Join us on the Hugging Face [Hub](https://huggingface.co/), [Discord](https://discord.com/invite/JfAtkvEtRb), or [forum](https://discuss.huggingface.co/) to collaborate and build models, datasets, and applications together.
|
||||
|
59
docs/source/en/model_doc/qwen3.md
Normal file
59
docs/source/en/model_doc/qwen3.md
Normal file
@ -0,0 +1,59 @@
|
||||
<!--Copyright 2024 The Qwen Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen3
|
||||
|
||||
## Overview
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
### Model Details
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
|
||||
## Usage tips
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
## Qwen3Config
|
||||
|
||||
[[autodoc]] Qwen3Config
|
||||
|
||||
## Qwen3Model
|
||||
|
||||
[[autodoc]] Qwen3Model
|
||||
- forward
|
||||
|
||||
## Qwen3ForCausalLM
|
||||
|
||||
[[autodoc]] Qwen3ForCausalLM
|
||||
- forward
|
||||
|
||||
## Qwen3ForSequenceClassification
|
||||
|
||||
[[autodoc]] Qwen3ForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Qwen3ForTokenClassification
|
||||
|
||||
[[autodoc]] Qwen3ForTokenClassification
|
||||
- forward
|
||||
|
||||
## Qwen3ForQuestionAnswering
|
||||
|
||||
[[autodoc]] Qwen3ForQuestionAnswering
|
||||
- forward
|
58
docs/source/en/model_doc/qwen3_moe.md
Normal file
58
docs/source/en/model_doc/qwen3_moe.md
Normal file
@ -0,0 +1,58 @@
|
||||
<!--Copyright 2024 The Qwen Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen3MoE
|
||||
|
||||
## Overview
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
### Model Details
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
## Usage tips
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
## Qwen3MoeConfig
|
||||
|
||||
[[autodoc]] Qwen3MoeConfig
|
||||
|
||||
## Qwen3MoeModel
|
||||
|
||||
[[autodoc]] Qwen3MoeModel
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForCausalLM
|
||||
|
||||
[[autodoc]] Qwen3MoeForCausalLM
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForSequenceClassification
|
||||
|
||||
[[autodoc]] Qwen3MoeForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForTokenClassification
|
||||
|
||||
[[autodoc]] Qwen3MoeForTokenClassification
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForQuestionAnswering
|
||||
|
||||
[[autodoc]] Qwen3MoeForQuestionAnswering
|
||||
- forward
|
@ -744,6 +744,8 @@ _import_structure = {
|
||||
"Qwen2VLConfig",
|
||||
"Qwen2VLProcessor",
|
||||
],
|
||||
"models.qwen3": ["Qwen3Config"],
|
||||
"models.qwen3_moe": ["Qwen3MoeConfig"],
|
||||
"models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"],
|
||||
"models.recurrent_gemma": ["RecurrentGemmaConfig"],
|
||||
"models.reformer": ["ReformerConfig"],
|
||||
@ -3441,6 +3443,26 @@ else:
|
||||
"Qwen2VLPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.qwen3"].extend(
|
||||
[
|
||||
"Qwen3ForCausalLM",
|
||||
"Qwen3ForQuestionAnswering",
|
||||
"Qwen3ForSequenceClassification",
|
||||
"Qwen3ForTokenClassification",
|
||||
"Qwen3Model",
|
||||
"Qwen3PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.qwen3_moe"].extend(
|
||||
[
|
||||
"Qwen3MoeForCausalLM",
|
||||
"Qwen3MoeForQuestionAnswering",
|
||||
"Qwen3MoeForSequenceClassification",
|
||||
"Qwen3MoeForTokenClassification",
|
||||
"Qwen3MoeModel",
|
||||
"Qwen3MoePreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.rag"].extend(
|
||||
[
|
||||
"RagModel",
|
||||
@ -5993,6 +6015,8 @@ if TYPE_CHECKING:
|
||||
Qwen2VLConfig,
|
||||
Qwen2VLProcessor,
|
||||
)
|
||||
from .models.qwen3 import Qwen3Config
|
||||
from .models.qwen3_moe import Qwen3MoeConfig
|
||||
from .models.rag import RagConfig, RagRetriever, RagTokenizer
|
||||
from .models.recurrent_gemma import RecurrentGemmaConfig
|
||||
from .models.reformer import ReformerConfig
|
||||
@ -8293,6 +8317,22 @@ if TYPE_CHECKING:
|
||||
Qwen2VLModel,
|
||||
Qwen2VLPreTrainedModel,
|
||||
)
|
||||
from .models.qwen3 import (
|
||||
Qwen3ForCausalLM,
|
||||
Qwen3ForQuestionAnswering,
|
||||
Qwen3ForSequenceClassification,
|
||||
Qwen3ForTokenClassification,
|
||||
Qwen3Model,
|
||||
Qwen3PreTrainedModel,
|
||||
)
|
||||
from .models.qwen3_moe import (
|
||||
Qwen3MoeForCausalLM,
|
||||
Qwen3MoeForQuestionAnswering,
|
||||
Qwen3MoeForSequenceClassification,
|
||||
Qwen3MoeForTokenClassification,
|
||||
Qwen3MoeModel,
|
||||
Qwen3MoePreTrainedModel,
|
||||
)
|
||||
from .models.rag import (
|
||||
RagModel,
|
||||
RagPreTrainedModel,
|
||||
|
@ -230,6 +230,8 @@ from . import (
|
||||
qwen2_audio,
|
||||
qwen2_moe,
|
||||
qwen2_vl,
|
||||
qwen3,
|
||||
qwen3_moe,
|
||||
rag,
|
||||
recurrent_gemma,
|
||||
reformer,
|
||||
|
@ -254,6 +254,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
|
||||
("qwen2_moe", "Qwen2MoeConfig"),
|
||||
("qwen2_vl", "Qwen2VLConfig"),
|
||||
("qwen3", "Qwen3Config"),
|
||||
("qwen3_moe", "Qwen3MoeConfig"),
|
||||
("rag", "RagConfig"),
|
||||
("realm", "RealmConfig"),
|
||||
("recurrent_gemma", "RecurrentGemmaConfig"),
|
||||
@ -609,6 +611,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
|
||||
("qwen2_moe", "Qwen2MoE"),
|
||||
("qwen2_vl", "Qwen2VL"),
|
||||
("qwen3", "Qwen3"),
|
||||
("qwen3_moe", "Qwen3MoE"),
|
||||
("rag", "RAG"),
|
||||
("realm", "REALM"),
|
||||
("recurrent_gemma", "RecurrentGemma"),
|
||||
|
@ -233,6 +233,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
|
||||
("qwen2_moe", "Qwen2MoeModel"),
|
||||
("qwen2_vl", "Qwen2VLModel"),
|
||||
("qwen3", "Qwen3Model"),
|
||||
("qwen3_moe", "Qwen3MoeModel"),
|
||||
("recurrent_gemma", "RecurrentGemmaModel"),
|
||||
("reformer", "ReformerModel"),
|
||||
("regnet", "RegNetModel"),
|
||||
@ -576,6 +578,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("qdqbert", "QDQBertLMHeadModel"),
|
||||
("qwen2", "Qwen2ForCausalLM"),
|
||||
("qwen2_moe", "Qwen2MoeForCausalLM"),
|
||||
("qwen3", "Qwen3ForCausalLM"),
|
||||
("qwen3_moe", "Qwen3MoeForCausalLM"),
|
||||
("recurrent_gemma", "RecurrentGemmaForCausalLM"),
|
||||
("reformer", "ReformerModelWithLMHead"),
|
||||
("rembert", "RemBertForCausalLM"),
|
||||
@ -1072,6 +1076,8 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("qdqbert", "QDQBertForSequenceClassification"),
|
||||
("qwen2", "Qwen2ForSequenceClassification"),
|
||||
("qwen2_moe", "Qwen2MoeForSequenceClassification"),
|
||||
("qwen3", "Qwen3ForSequenceClassification"),
|
||||
("qwen3_moe", "Qwen3MoeForSequenceClassification"),
|
||||
("reformer", "ReformerForSequenceClassification"),
|
||||
("rembert", "RemBertForSequenceClassification"),
|
||||
("roberta", "RobertaForSequenceClassification"),
|
||||
@ -1153,6 +1159,8 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
("qdqbert", "QDQBertForQuestionAnswering"),
|
||||
("qwen2", "Qwen2ForQuestionAnswering"),
|
||||
("qwen2_moe", "Qwen2MoeForQuestionAnswering"),
|
||||
("qwen3", "Qwen3ForQuestionAnswering"),
|
||||
("qwen3_moe", "Qwen3MoeForQuestionAnswering"),
|
||||
("reformer", "ReformerForQuestionAnswering"),
|
||||
("rembert", "RemBertForQuestionAnswering"),
|
||||
("roberta", "RobertaForQuestionAnswering"),
|
||||
@ -1257,6 +1265,8 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("qdqbert", "QDQBertForTokenClassification"),
|
||||
("qwen2", "Qwen2ForTokenClassification"),
|
||||
("qwen2_moe", "Qwen2MoeForTokenClassification"),
|
||||
("qwen3", "Qwen3ForTokenClassification"),
|
||||
("qwen3_moe", "Qwen3MoeForTokenClassification"),
|
||||
("rembert", "RemBertForTokenClassification"),
|
||||
("roberta", "RobertaForTokenClassification"),
|
||||
("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
|
||||
|
@ -454,6 +454,20 @@ else:
|
||||
),
|
||||
),
|
||||
("qwen2_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"qwen3",
|
||||
(
|
||||
"Qwen2Tokenizer",
|
||||
"Qwen2TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"qwen3_moe",
|
||||
(
|
||||
"Qwen2Tokenizer",
|
||||
"Qwen2TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("rag", ("RagTokenizer", None)),
|
||||
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
|
@ -26,8 +26,7 @@ class Qwen2MoeConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a
|
||||
Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen1.5-MoE-A2.7B" [Qwen/Qwen1.5-MoE-A2.7B"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B").
|
||||
with the defaults will yield a similar configuration to that of [Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
27
src/transformers/models/qwen3/__init__.py
Normal file
27
src/transformers/models/qwen3/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_qwen3 import *
|
||||
from .modeling_qwen3 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
212
src/transformers/models/qwen3/configuration_qwen3.py
Normal file
212
src/transformers/models/qwen3/configuration_qwen3.py
Normal file
@ -0,0 +1,212 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Qwen3 model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Qwen3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a
|
||||
Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen3-8B [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 151936):
|
||||
Vocabulary size of the Qwen3 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen3Model`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 22016):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 32):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
||||
head_dim (`int`, *optional*, defaults to 128):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 28):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen3Model, Qwen3Config
|
||||
|
||||
>>> # Initializing a Qwen3 style configuration
|
||||
>>> configuration = Qwen3Config()
|
||||
|
||||
>>> # Initializing a model from the Qwen3-8B style configuration
|
||||
>>> model = Qwen3Model(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen3"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for base model `Qwen3`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151936,
|
||||
hidden_size=4096,
|
||||
intermediate_size=22016,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=32,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=28,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Qwen3Config"]
|
1202
src/transformers/models/qwen3/modeling_qwen3.py
Normal file
1202
src/transformers/models/qwen3/modeling_qwen3.py
Normal file
File diff suppressed because it is too large
Load Diff
203
src/transformers/models/qwen3/modular_qwen3.py
Normal file
203
src/transformers/models/qwen3/modular_qwen3.py
Normal file
@ -0,0 +1,203 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Qwen3 model."""
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import CausalLMOutputWithPast
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
logging,
|
||||
)
|
||||
from ..gemma.modeling_gemma import GemmaMLP
|
||||
from ..llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForTokenClassification,
|
||||
LlamaRMSNorm,
|
||||
apply_rotary_pos_emb,
|
||||
eager_attention_forward,
|
||||
)
|
||||
from ..mistral.modeling_mistral import MistralModel
|
||||
from .configuration_qwen3 import Qwen3Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B"
|
||||
|
||||
|
||||
class Qwen3RMSNorm(LlamaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class Qwen3MLP(GemmaMLP):
|
||||
pass
|
||||
|
||||
|
||||
class Qwen3Attention(LlamaAttention):
|
||||
def __init__(self, config: Qwen3Config, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
|
||||
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
|
||||
self.sliding_window = config.sliding_window
|
||||
if not (
|
||||
self.config.use_sliding_window
|
||||
and getattr(self.config, "sliding_window", None) is not None
|
||||
and self.layer_idx >= self.config.max_window_layers
|
||||
):
|
||||
self.sliding_window = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
sliding_window=self.sliding_window, # diff with Llama
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Qwen3DecoderLayer(LlamaDecoderLayer):
|
||||
def __init__(self, config: Qwen3Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Qwen3MLP(config)
|
||||
if (
|
||||
config.sliding_window and config._attn_implementation != "flash_attention_2"
|
||||
): # diff with Llama is this warning
|
||||
logger.warning_once(
|
||||
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
||||
"unexpected results may be encountered."
|
||||
)
|
||||
|
||||
|
||||
class Qwen3Model(MistralModel): # mistral model creates sliding window
|
||||
pass
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
class Qwen3ForCausalLM(LlamaForCausalLM):
|
||||
def forward(
|
||||
self,
|
||||
**super_kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Qwen3ForCausalLM
|
||||
|
||||
>>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
return super().forward(**super_kwargs)
|
||||
|
||||
|
||||
class Qwen3ForSequenceClassification(LlamaForSequenceClassification):
|
||||
pass
|
||||
|
||||
|
||||
class Qwen3ForTokenClassification(LlamaForTokenClassification):
|
||||
pass
|
||||
|
||||
|
||||
class Qwen3ForQuestionAnswering(LlamaForQuestionAnswering):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Qwen3ForCausalLM",
|
||||
"Qwen3ForQuestionAnswering",
|
||||
"Qwen3Model",
|
||||
"Qwen3PreTrainedModel", # noqa: F822
|
||||
"Qwen3ForSequenceClassification",
|
||||
"Qwen3ForTokenClassification",
|
||||
]
|
27
src/transformers/models/qwen3_moe/__init__.py
Normal file
27
src/transformers/models/qwen3_moe/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_qwen3_moe import *
|
||||
from .modeling_qwen3_moe import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
240
src/transformers/models/qwen3_moe/configuration_qwen3_moe.py
Normal file
240
src/transformers/models/qwen3_moe/configuration_qwen3_moe.py
Normal file
@ -0,0 +1,240 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Qwen3MoE model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Qwen3MoeConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a
|
||||
Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of [Qwen/Qwen3-MoE-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 151936):
|
||||
Vocabulary size of the Qwen3MoE model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen3MoeModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 2048):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 6144):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 4):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 28):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
decoder_sparse_step (`int`, *optional*, defaults to 1):
|
||||
The frequency of the MoE layer.
|
||||
moe_intermediate_size (`int`, *optional*, defaults to 768):
|
||||
Intermediate size of the routed expert.
|
||||
num_experts_per_tok (`int`, *optional*, defaults to 8):
|
||||
Number of selected experts.
|
||||
num_experts (`int`, *optional*, defaults to 128):
|
||||
Number of routed experts.
|
||||
norm_topk_prob (`bool`, *optional*, defaults to `False`):
|
||||
Whether to normalize the topk probabilities.
|
||||
output_router_logits (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the router logits should be returned by the model. Enabeling this will also
|
||||
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
|
||||
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
||||
The aux loss factor for the total loss.
|
||||
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
|
||||
Indicate which layers use Qwen3MoeMLP rather than Qwen3MoeSparseMoeBlock
|
||||
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
|
||||
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen3MoeModel, Qwen3MoeConfig
|
||||
|
||||
>>> # Initializing a Qwen3MoE style configuration
|
||||
>>> configuration = Qwen3MoeConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen3-15B-A2B" style configuration
|
||||
>>> model = Qwen3MoeModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen3_moe"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for base model `Qwen3Moe`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151936,
|
||||
hidden_size=2048,
|
||||
intermediate_size=6144,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=4,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=28,
|
||||
attention_dropout=0.0,
|
||||
decoder_sparse_step=1,
|
||||
moe_intermediate_size=768,
|
||||
num_experts_per_tok=8,
|
||||
num_experts=128,
|
||||
norm_topk_prob=False,
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.001,
|
||||
mlp_only_layers=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window if use_sliding_window else None
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
# MoE arguments
|
||||
self.decoder_sparse_step = decoder_sparse_step
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_experts = num_experts
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Qwen3MoeConfig"]
|
1412
src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
Normal file
1412
src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
Normal file
File diff suppressed because it is too large
Load Diff
376
src/transformers/models/qwen3_moe/modular_qwen3_moe.py
Normal file
376
src/transformers/models/qwen3_moe/modular_qwen3_moe.py
Normal file
@ -0,0 +1,376 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Qwen3 model."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import MoeCausalLMOutputWithPast
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
logging,
|
||||
)
|
||||
from ..llama.modeling_llama import (
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForTokenClassification,
|
||||
LlamaRMSNorm,
|
||||
)
|
||||
from ..mixtral.modeling_mixtral import (
|
||||
MixtralForCausalLM,
|
||||
MixtralModel,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer
|
||||
from ..qwen3.modeling_qwen3 import Qwen3Attention
|
||||
from .configuration_qwen3_moe import Qwen3MoeConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "Qwen/Qwen3-MoE-15B-A2B"
|
||||
|
||||
|
||||
class Qwen3MoeAttention(Qwen3Attention): # This is the main diff with qwen2Moe!
|
||||
pass
|
||||
|
||||
|
||||
class Qwen3MoeMLP(nn.Module):
|
||||
def __init__(self, config, intermediate_size=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
|
||||
# gating
|
||||
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
|
||||
self.experts = nn.ModuleList(
|
||||
[Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
""" """
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
|
||||
# One hot encode the selected experts to create an expert mask
|
||||
# this will be used to easily index which expert is going to be sollicitated
|
||||
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
|
||||
# Loop over all available experts in the model and perform the computation on each expert
|
||||
for expert_idx in range(self.num_experts):
|
||||
expert_layer = self.experts[expert_idx]
|
||||
idx, top_x = torch.where(expert_mask[expert_idx])
|
||||
|
||||
# Index the correct hidden states and compute the expert hidden state for
|
||||
# the current expert. We need to make sure to multiply the output hidden
|
||||
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
||||
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
|
||||
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
||||
|
||||
# However `index_add_` only support torch tensors for indexing so we'll use
|
||||
# the `top_x` tensor here.
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
|
||||
|
||||
class Qwen3MoeRMSNorm(LlamaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class Qwen3MoeDecoderLayer(Qwen2MoeDecoderLayer, nn.Module):
|
||||
def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
|
||||
nn.Module().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Qwen3MoeAttention(config, layer_idx)
|
||||
self.mlp = Qwen3MoeMLP(config)
|
||||
|
||||
self.self_attn = Qwen3MoeAttention(config, layer_idx)
|
||||
|
||||
if (layer_idx not in config.mlp_only_layers) and (
|
||||
config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
|
||||
):
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(config)
|
||||
else:
|
||||
self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
|
||||
|
||||
self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_router_logits (`bool`, *optional*):
|
||||
Whether or not to return the logits of all the routers. They are useful for computing the router loss,
|
||||
and should not be returned during inference.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states, router_logits = hidden_states
|
||||
else:
|
||||
router_logits = None
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if output_router_logits:
|
||||
outputs += (router_logits,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Qwen3MoeModel(MixtralModel):
|
||||
def __init__(self, config: Qwen3MoeConfig):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[Qwen3MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
class Qwen3MoeForCausalLM(MixtralForCausalLM):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = Qwen3MoeModel(config)
|
||||
self.num_experts = config.num_experts
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
|
||||
|
||||
>>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_router_logits = (
|
||||
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
||||
)
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_router_logits=output_router_logits,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits if return_dict else outputs[-1],
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
if output_router_logits:
|
||||
output = (aux_loss,) + output
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
|
||||
class Qwen3MoeForSequenceClassification(LlamaForSequenceClassification):
|
||||
pass
|
||||
|
||||
|
||||
class Qwen3MoeForTokenClassification(LlamaForTokenClassification):
|
||||
pass
|
||||
|
||||
|
||||
class Qwen3MoeForQuestionAnswering(LlamaForQuestionAnswering):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Qwen3MoeForCausalLM",
|
||||
"Qwen3MoeForQuestionAnswering",
|
||||
"Qwen3MoeModel",
|
||||
"Qwen3MoePreTrainedModel", # noqa: F822
|
||||
"Qwen3MoeForSequenceClassification",
|
||||
"Qwen3MoeForTokenClassification",
|
||||
]
|
@ -8222,6 +8222,90 @@ class Qwen2VLPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3ForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3ForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3ForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3ForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3MoeForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3MoeForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3MoeForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3MoeForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3MoeModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen3MoePreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RagModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -158,6 +158,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
||||
"plbart",
|
||||
"qwen2",
|
||||
"qwen2_moe",
|
||||
"qwen3",
|
||||
"qwen3_moe",
|
||||
"resnet",
|
||||
"roberta",
|
||||
"segformer",
|
||||
|
0
tests/models/qwen3/__init__.py
Normal file
0
tests/models/qwen3/__init__.py
Normal file
1039
tests/models/qwen3/test_modeling_qwen3.py
Normal file
1039
tests/models/qwen3/test_modeling_qwen3.py
Normal file
File diff suppressed because it is too large
Load Diff
0
tests/models/qwen3_moe/__init__.py
Normal file
0
tests/models/qwen3_moe/__init__.py
Normal file
631
tests/models/qwen3_moe/test_modeling_qwen3_moe.py
Normal file
631
tests/models/qwen3_moe/test_modeling_qwen3_moe.py
Normal file
@ -0,0 +1,631 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Qwen3MoE model."""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import AutoTokenizer, Qwen3MoeConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
Qwen3MoeForCausalLM,
|
||||
Qwen3MoeForQuestionAnswering,
|
||||
Qwen3MoeForSequenceClassification,
|
||||
Qwen3MoeForTokenClassification,
|
||||
Qwen3MoeModel,
|
||||
)
|
||||
|
||||
|
||||
class Qwen3MoeModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
max_window_layers=3,
|
||||
use_sliding_window=True,
|
||||
sliding_window=50,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
head_dim=16,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
expert_interval=1,
|
||||
moe_intermediate_size=12,
|
||||
num_experts_per_tok=2,
|
||||
num_experts=8,
|
||||
norm_topk_prob=False,
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.001,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.max_window_layers = max_window_layers
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.scope = scope
|
||||
self.expert_interval = expert_interval
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_experts = num_experts
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def get_config(self):
|
||||
return Qwen3MoeConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
max_window_layers=self.max_window_layers,
|
||||
use_sliding_window=self.use_sliding_window,
|
||||
sliding_window=self.sliding_window,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
num_key_value_heads=self.num_key_value_heads,
|
||||
head_dim=self.head_dim,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
expert_interval=self.expert_interval,
|
||||
moe_intermediate_size=self.moe_intermediate_size,
|
||||
num_experts_per_tok=self.num_experts_per_tok,
|
||||
num_experts=self.num_experts,
|
||||
norm_topk_prob=self.norm_topk_prob,
|
||||
output_router_logits=self.output_router_logits,
|
||||
router_aux_loss_coef=self.router_aux_loss_coef,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
pad_token_id=self.pad_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
)
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen3Moe
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = Qwen3MoeModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Qwen3Moe
|
||||
def create_and_check_model_as_decoder(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.add_cross_attention = True
|
||||
model = Qwen3MoeModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Qwen3Moe
|
||||
def create_and_check_for_causal_lm(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
model = Qwen3MoeForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Qwen3Moe
|
||||
def create_and_check_decoder_model_past_large_inputs(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.is_decoder = True
|
||||
config.add_cross_attention = True
|
||||
model = Qwen3MoeForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
outputs = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# create hypothetical multiple next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||
|
||||
output_from_no_past = model(
|
||||
next_input_ids,
|
||||
attention_mask=next_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
output_from_past = model(
|
||||
next_tokens,
|
||||
attention_mask=next_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen3Moe
|
||||
class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
Qwen3MoeModel,
|
||||
Qwen3MoeForCausalLM,
|
||||
Qwen3MoeForSequenceClassification,
|
||||
Qwen3MoeForTokenClassification,
|
||||
Qwen3MoeForQuestionAnswering,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Qwen3MoeModel,
|
||||
"text-classification": Qwen3MoeForSequenceClassification,
|
||||
"token-classification": Qwen3MoeForTokenClassification,
|
||||
"text-generation": Qwen3MoeForCausalLM,
|
||||
"zero-shot": Qwen3MoeForSequenceClassification,
|
||||
"question-answering": Qwen3MoeForQuestionAnswering,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Qwen3MoeModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Qwen3MoeConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_torch_fx_output_loss(self):
|
||||
super().test_torch_fx_output_loss()
|
||||
|
||||
def test_Qwen3Moe_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
print(config)
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
model = Qwen3MoeForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_Qwen3Moe_sequence_classification_model_for_single_label(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
config.problem_type = "single_label_classification"
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
model = Qwen3MoeForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_Qwen3Moe_sequence_classification_model_for_multi_label(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
config.problem_type = "multi_label_classification"
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor(
|
||||
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
|
||||
).to(torch.float)
|
||||
model = Qwen3MoeForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen3Moe,llama->Qwen3Moe
|
||||
def test_Qwen3Moe_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = Qwen3MoeForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Qwen3Moe buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Qwen3Moe uses GQA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
self.skipTest(reason="Qwen3Moe flash attention does not support right padding")
|
||||
|
||||
# Ignore copy
|
||||
def test_load_balancing_loss(self):
|
||||
r"""
|
||||
Let's make sure we can actually compute the loss and do a backward on it.
|
||||
"""
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
config.num_experts = 8
|
||||
config.expert_interval = 2
|
||||
config.output_router_logits = True
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
model = Qwen3MoeForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask)
|
||||
self.assertEqual(result.router_logits[0].shape, (91, config.num_experts))
|
||||
torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)
|
||||
|
||||
# First, we make sure that adding padding tokens doesn't change the loss
|
||||
# loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding)
|
||||
pad_length = 1000
|
||||
# Add padding tokens (assume that pad_token_id=1) to input_ids
|
||||
padding_block = torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(torch_device)
|
||||
padded_input_ids = torch.cat((padding_block, input_ids), dim=1) # this is to simulate padding to the left
|
||||
padded_attention_mask = padded_input_ids.ne(1).to(torch_device)
|
||||
|
||||
padded_result = model(padded_input_ids, attention_mask=padded_attention_mask)
|
||||
torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4)
|
||||
|
||||
# We make sure that the loss of includding padding tokens != the loss without padding tokens
|
||||
# if attention_mask=None --> we don't exclude padding tokens
|
||||
include_padding_result = model(padded_input_ids, attention_mask=None)
|
||||
|
||||
# This is to mimic torch.testing.assert_not_close
|
||||
self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item())
|
||||
|
||||
|
||||
@require_torch
|
||||
class Qwen3MoeIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_model_15b_a2b_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-15B-A2B-Base", device_map="auto")
|
||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||
with torch.no_grad():
|
||||
out = model(input_ids).logits.float().cpu()
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-1.1184, 1.1356, 9.2112, 8.0254, 5.1663, 7.9287, 8.9245, 10.0671]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
EXPECTED_SLICE = torch.tensor([7.5938, 2.6094, 4.0312, 4.0938, 2.5156, 2.7812, 2.9688, 1.5547, 1.3984, 2.2344, 3.0156, 3.1562, 1.1953, 3.2500, 1.0938, 8.4375, 9.5625, 9.0625, 7.5625, 7.5625, 7.9062, 7.2188, 7.0312, 6.9375, 8.0625, 1.7266, 0.9141, 3.7969, 5.3438, 3.9844]) # fmt: skip
|
||||
print(out[0, 0, :30])
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
def test_model_15b_a2b_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"""To be or not to be, that is the question. Whether 'tis nobler in the mind to suffer the sl"""
|
||||
)
|
||||
prompt = "To be or not to"
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False)
|
||||
model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-15B-A2B-Base", device_map="auto")
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
@require_flash_attn
|
||||
@pytest.mark.flash_attn_test
|
||||
def test_model_15b_a2b_long_prompt(self):
|
||||
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
|
||||
# An input with 4097 tokens that is above the size of the sliding window
|
||||
input_ids = [1] + [306, 338] * 2048
|
||||
model = Qwen3MoeForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen3-15B-A2B-Base",
|
||||
device_map="auto",
|
||||
load_in_4bit=True,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||
|
||||
# Assisted generation
|
||||
assistant_model = model
|
||||
assistant_model.generation_config.num_assistant_tokens = 2
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||
|
||||
del assistant_model
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
@require_torch_sdpa
|
||||
def test_model_15b_a2b_long_prompt_sdpa(self):
|
||||
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
|
||||
# An input with 4097 tokens that is above the size of the sliding window
|
||||
input_ids = [1] + [306, 338] * 2048
|
||||
model = Qwen3MoeForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen3-15B-A2B-Base",
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||
|
||||
# Assisted generation
|
||||
assistant_model = model
|
||||
assistant_model.generation_config.num_assistant_tokens = 2
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
|
||||
generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||
|
||||
del assistant_model
|
||||
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"""To be or not to be, that is the question. Whether 'tis nobler in the mind to suffer the sl"""
|
||||
)
|
||||
prompt = "To be or not to"
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False)
|
||||
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
def test_speculative_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"To be or not to be, that is the question: whether 'tis nobler in the mind to suffer the sl"
|
||||
)
|
||||
prompt = "To be or not to"
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False)
|
||||
model = Qwen3MoeForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen3-15B-A2B-Base", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
assistant_model = Qwen3MoeForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen3-15B-A2B-Base", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
||||
|
||||
# greedy generation outputs
|
||||
set_seed(0)
|
||||
generated_ids = model.generate(
|
||||
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
@ -22,6 +22,8 @@ FILES_TO_PARSE = [
|
||||
os.path.join(MODEL_ROOT, "olmo", "modular_olmo.py"),
|
||||
os.path.join(MODEL_ROOT, "rt_detr", "modular_rt_detr.py"),
|
||||
os.path.join(MODEL_ROOT, "qwen2", "modular_qwen2.py"),
|
||||
os.path.join(MODEL_ROOT, "qwen3", "modular_qwen3.py"),
|
||||
os.path.join(MODEL_ROOT, "qwen3", "modular_qwen3_moe.py"),
|
||||
os.path.join(MODEL_ROOT, "llava_next_video", "modular_llava_next_video.py"),
|
||||
os.path.join(MODEL_ROOT, "cohere2", "modular_cohere2.py"),
|
||||
os.path.join(MODEL_ROOT, "modernbert", "modular_modernbert.py"),
|
||||
|
@ -47,6 +47,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
|
||||
"LlamaConfig",
|
||||
"GraniteConfig",
|
||||
"GraniteMoeConfig",
|
||||
"Qwen3MoeConfig",
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user