From 6acd5aecb34b579729c29a26aa470f92081d88fa Mon Sep 17 00:00:00 2001 From: Bo Zheng Date: Mon, 31 Mar 2025 15:50:49 +0800 Subject: [PATCH] Adding Qwen3 and Qwen3MoE (#36878) * Initial commit for Qwen3 * fix and add tests for qwen3 & qwen3_moe * rename models for tests. * fix * fix * fix and add docs. * fix model name in docs. * simplify modular and fix configuration issues * Fix the red CI: ruff was updated * revert ruff, version was wrong * fix qwen3moe. * fix * make sure MOE can load * fix copies --------- Co-authored-by: Arthur Zucker --- docs/source/en/_toctree.yml | 4 + docs/source/en/index.md | 1 - docs/source/en/model_doc/qwen3.md | 59 + docs/source/en/model_doc/qwen3_moe.md | 58 + src/transformers/__init__.py | 40 + src/transformers/models/__init__.py | 2 + .../models/auto/configuration_auto.py | 4 + src/transformers/models/auto/modeling_auto.py | 10 + .../models/auto/tokenization_auto.py | 14 + .../qwen2_moe/configuration_qwen2_moe.py | 3 +- src/transformers/models/qwen3/__init__.py | 27 + .../models/qwen3/configuration_qwen3.py | 212 +++ .../models/qwen3/modeling_qwen3.py | 1202 ++++++++++++++ .../models/qwen3/modular_qwen3.py | 203 +++ src/transformers/models/qwen3_moe/__init__.py | 27 + .../qwen3_moe/configuration_qwen3_moe.py | 240 +++ .../models/qwen3_moe/modeling_qwen3_moe.py | 1412 +++++++++++++++++ .../models/qwen3_moe/modular_qwen3_moe.py | 376 +++++ src/transformers/utils/dummy_pt_objects.py | 84 + src/transformers/utils/fx.py | 2 + tests/models/qwen3/__init__.py | 0 tests/models/qwen3/test_modeling_qwen3.py | 1039 ++++++++++++ tests/models/qwen3_moe/__init__.py | 0 .../qwen3_moe/test_modeling_qwen3_moe.py | 631 ++++++++ .../modular/test_conversion_order.py | 2 + utils/check_config_docstrings.py | 1 + 26 files changed, 5650 insertions(+), 3 deletions(-) create mode 100644 docs/source/en/model_doc/qwen3.md create mode 100644 docs/source/en/model_doc/qwen3_moe.md create mode 100644 src/transformers/models/qwen3/__init__.py create mode 100644 src/transformers/models/qwen3/configuration_qwen3.py create mode 100644 src/transformers/models/qwen3/modeling_qwen3.py create mode 100644 src/transformers/models/qwen3/modular_qwen3.py create mode 100644 src/transformers/models/qwen3_moe/__init__.py create mode 100644 src/transformers/models/qwen3_moe/configuration_qwen3_moe.py create mode 100644 src/transformers/models/qwen3_moe/modeling_qwen3_moe.py create mode 100644 src/transformers/models/qwen3_moe/modular_qwen3_moe.py create mode 100644 tests/models/qwen3/__init__.py create mode 100644 tests/models/qwen3/test_modeling_qwen3.py create mode 100644 tests/models/qwen3_moe/__init__.py create mode 100644 tests/models/qwen3_moe/test_modeling_qwen3_moe.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4dd3cc6d3d7..ed306fac947 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -603,6 +603,10 @@ title: Qwen2 - local: model_doc/qwen2_moe title: Qwen2MoE + - local: model_doc/qwen3 + title: Qwen3 + - local: model_doc/qwen3_moe + title: Qwen3MoE - local: model_doc/rag title: RAG - local: model_doc/realm diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 8120e12937a..5c3898ce788 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -43,4 +43,3 @@ Transformers is designed for developers and machine learning engineers and resea -Join us on the Hugging Face [Hub](https://huggingface.co/), [Discord](https://discord.com/invite/JfAtkvEtRb), or [forum](https://discuss.huggingface.co/) to collaborate and build 
models, datasets, and applications together. diff --git a/docs/source/en/model_doc/qwen3.md b/docs/source/en/model_doc/qwen3.md new file mode 100644 index 00000000000..e3f3c266091 --- /dev/null +++ b/docs/source/en/model_doc/qwen3.md @@ -0,0 +1,59 @@ + + +# Qwen3 + +## Overview + +To be released with the official model launch. + +### Model Details + +To be released with the official model launch. + + +## Usage tips + +To be released with the official model launch. + +## Qwen3Config + +[[autodoc]] Qwen3Config + +## Qwen3Model + +[[autodoc]] Qwen3Model + - forward + +## Qwen3ForCausalLM + +[[autodoc]] Qwen3ForCausalLM + - forward + +## Qwen3ForSequenceClassification + +[[autodoc]] Qwen3ForSequenceClassification + - forward + +## Qwen3ForTokenClassification + +[[autodoc]] Qwen3ForTokenClassification + - forward + +## Qwen3ForQuestionAnswering + +[[autodoc]] Qwen3ForQuestionAnswering + - forward diff --git a/docs/source/en/model_doc/qwen3_moe.md b/docs/source/en/model_doc/qwen3_moe.md new file mode 100644 index 00000000000..1de4af1a5bd --- /dev/null +++ b/docs/source/en/model_doc/qwen3_moe.md @@ -0,0 +1,58 @@ + + +# Qwen3MoE + +## Overview + +To be released with the official model launch. + +### Model Details + +To be released with the official model launch. + +## Usage tips + +To be released with the official model launch. + +## Qwen3MoeConfig + +[[autodoc]] Qwen3MoeConfig + +## Qwen3MoeModel + +[[autodoc]] Qwen3MoeModel + - forward + +## Qwen3MoeForCausalLM + +[[autodoc]] Qwen3MoeForCausalLM + - forward + +## Qwen3MoeForSequenceClassification + +[[autodoc]] Qwen3MoeForSequenceClassification + - forward + +## Qwen3MoeForTokenClassification + +[[autodoc]] Qwen3MoeForTokenClassification + - forward + +## Qwen3MoeForQuestionAnswering + +[[autodoc]] Qwen3MoeForQuestionAnswering + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3892541b66d..3a11b3e36b7 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -744,6 +744,8 @@ _import_structure = { "Qwen2VLConfig", "Qwen2VLProcessor", ], + "models.qwen3": ["Qwen3Config"], + "models.qwen3_moe": ["Qwen3MoeConfig"], "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], "models.recurrent_gemma": ["RecurrentGemmaConfig"], "models.reformer": ["ReformerConfig"], @@ -3441,6 +3443,26 @@ else: "Qwen2VLPreTrainedModel", ] ) + _import_structure["models.qwen3"].extend( + [ + "Qwen3ForCausalLM", + "Qwen3ForQuestionAnswering", + "Qwen3ForSequenceClassification", + "Qwen3ForTokenClassification", + "Qwen3Model", + "Qwen3PreTrainedModel", + ] + ) + _import_structure["models.qwen3_moe"].extend( + [ + "Qwen3MoeForCausalLM", + "Qwen3MoeForQuestionAnswering", + "Qwen3MoeForSequenceClassification", + "Qwen3MoeForTokenClassification", + "Qwen3MoeModel", + "Qwen3MoePreTrainedModel", + ] + ) _import_structure["models.rag"].extend( [ "RagModel", @@ -5993,6 +6015,8 @@ if TYPE_CHECKING: Qwen2VLConfig, Qwen2VLProcessor, ) + from .models.qwen3 import Qwen3Config + from .models.qwen3_moe import Qwen3MoeConfig from .models.rag import RagConfig, RagRetriever, RagTokenizer from .models.recurrent_gemma import RecurrentGemmaConfig from .models.reformer import ReformerConfig @@ -8293,6 +8317,22 @@ if TYPE_CHECKING: Qwen2VLModel, Qwen2VLPreTrainedModel, ) + from .models.qwen3 import ( + Qwen3ForCausalLM, + Qwen3ForQuestionAnswering, + Qwen3ForSequenceClassification, + Qwen3ForTokenClassification, + Qwen3Model, + Qwen3PreTrainedModel, + ) + from .models.qwen3_moe import ( + Qwen3MoeForCausalLM, + 
Qwen3MoeForQuestionAnswering, + Qwen3MoeForSequenceClassification, + Qwen3MoeForTokenClassification, + Qwen3MoeModel, + Qwen3MoePreTrainedModel, + ) from .models.rag import ( RagModel, RagPreTrainedModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 64041f7f0fe..b2cde9f4bc5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -230,6 +230,8 @@ from . import ( qwen2_audio, qwen2_moe, qwen2_vl, + qwen3, + qwen3_moe, rag, recurrent_gemma, reformer, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 03ee0a5a2f0..f5673e09cb7 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -254,6 +254,8 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"), ("qwen2_moe", "Qwen2MoeConfig"), ("qwen2_vl", "Qwen2VLConfig"), + ("qwen3", "Qwen3Config"), + ("qwen3_moe", "Qwen3MoeConfig"), ("rag", "RagConfig"), ("realm", "RealmConfig"), ("recurrent_gemma", "RecurrentGemmaConfig"), @@ -609,6 +611,8 @@ MODEL_NAMES_MAPPING = OrderedDict( ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoE"), ("qwen2_vl", "Qwen2VL"), + ("qwen3", "Qwen3"), + ("qwen3_moe", "Qwen3MoE"), ("rag", "RAG"), ("realm", "REALM"), ("recurrent_gemma", "RecurrentGemma"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index c98ac16988a..7498f6d22c1 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -233,6 +233,8 @@ MODEL_MAPPING_NAMES = OrderedDict( ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoeModel"), ("qwen2_vl", "Qwen2VLModel"), + ("qwen3", "Qwen3Model"), + ("qwen3_moe", "Qwen3MoeModel"), ("recurrent_gemma", "RecurrentGemmaModel"), ("reformer", "ReformerModel"), ("regnet", "RegNetModel"), @@ -576,6 +578,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("qdqbert", "QDQBertLMHeadModel"), ("qwen2", "Qwen2ForCausalLM"), ("qwen2_moe", "Qwen2MoeForCausalLM"), + ("qwen3", "Qwen3ForCausalLM"), + ("qwen3_moe", "Qwen3MoeForCausalLM"), ("recurrent_gemma", "RecurrentGemmaForCausalLM"), ("reformer", "ReformerModelWithLMHead"), ("rembert", "RemBertForCausalLM"), @@ -1072,6 +1076,8 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("qdqbert", "QDQBertForSequenceClassification"), ("qwen2", "Qwen2ForSequenceClassification"), ("qwen2_moe", "Qwen2MoeForSequenceClassification"), + ("qwen3", "Qwen3ForSequenceClassification"), + ("qwen3_moe", "Qwen3MoeForSequenceClassification"), ("reformer", "ReformerForSequenceClassification"), ("rembert", "RemBertForSequenceClassification"), ("roberta", "RobertaForSequenceClassification"), @@ -1153,6 +1159,8 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("qdqbert", "QDQBertForQuestionAnswering"), ("qwen2", "Qwen2ForQuestionAnswering"), ("qwen2_moe", "Qwen2MoeForQuestionAnswering"), + ("qwen3", "Qwen3ForQuestionAnswering"), + ("qwen3_moe", "Qwen3MoeForQuestionAnswering"), ("reformer", "ReformerForQuestionAnswering"), ("rembert", "RemBertForQuestionAnswering"), ("roberta", "RobertaForQuestionAnswering"), @@ -1257,6 +1265,8 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("qdqbert", "QDQBertForTokenClassification"), ("qwen2", "Qwen2ForTokenClassification"), ("qwen2_moe", "Qwen2MoeForTokenClassification"), + ("qwen3", "Qwen3ForTokenClassification"), + 
("qwen3_moe", "Qwen3MoeForTokenClassification"), ("rembert", "RemBertForTokenClassification"), ("roberta", "RobertaForTokenClassification"), ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 7756268821e..42ff109ccab 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -454,6 +454,20 @@ else: ), ), ("qwen2_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ( + "qwen3", + ( + "Qwen2Tokenizer", + "Qwen2TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "qwen3_moe", + ( + "Qwen2Tokenizer", + "Qwen2TokenizerFast" if is_tokenizers_available() else None, + ), + ), ("rag", ("RagTokenizer", None)), ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)), ( diff --git a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py index e0b1c122529..113f6c09eae 100644 --- a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py @@ -26,8 +26,7 @@ class Qwen2MoeConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen1.5-MoE-A2.7B" [Qwen/Qwen1.5-MoE-A2.7B"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B"). + with the defaults will yield a similar configuration to that of [Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. diff --git a/src/transformers/models/qwen3/__init__.py b/src/transformers/models/qwen3/__init__.py new file mode 100644 index 00000000000..ae1ec7ea525 --- /dev/null +++ b/src/transformers/models/qwen3/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen3 import * + from .modeling_qwen3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/qwen3/configuration_qwen3.py b/src/transformers/models/qwen3/configuration_qwen3.py new file mode 100644 index 00000000000..06e527ce53f --- /dev/null +++ b/src/transformers/models/qwen3/configuration_qwen3.py @@ -0,0 +1,212 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a + Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-8B [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + head_dim (`int`, *optional*, defaults to 128): + The attention head dimension. 
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen3Model, Qwen3Config + + >>> # Initializing a Qwen3 style configuration + >>> configuration = Qwen3Config() + + >>> # Initializing a model from the Qwen3-8B style configuration + >>> model = Qwen3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen3` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + head_dim=128, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. 
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Qwen3Config"] diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py new file mode 100644 index 00000000000..bff2c96667d --- /dev/null +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -0,0 +1,1202 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from .configuration_qwen3 import Qwen3Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B" +_CONFIG_FOR_DOC = "Qwen3Config" + + +class Qwen3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size 
+ self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.sliding_window = config.sliding_window + if not ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + self.sliding_window = None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3DecoderLayer(nn.Module): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if ( + config.sliding_window and config._attn_implementation != "flash_attention_2" + ): # diff with Llama is this warning + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+ ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Qwen3RotaryEmbedding(nn.Module): + def __init__(self, config: Qwen3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # 
Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +QWEN3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen3 Model outputting raw hidden-states without any specific head on top.", + QWEN3_START_DOCSTRING, +) +class Qwen3PreTrainedModel(PreTrainedModel): + config_class = Qwen3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). 
+ + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen3 Model outputting raw hidden-states without any specific head on top.", + QWEN3_START_DOCSTRING, +) +class Qwen3Model(Qwen3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen3DecoderLayer`] + + Args: + config: Qwen3Config + """ + + def __init__(self, config: Qwen3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen3Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ config (`Qwen3Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
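The `_prepare_4d_causal_attention_mask_with_cache_position` helper above builds an additive mask that combines the causal constraint with an optional sliding-window constraint. A standalone sketch of that masking logic on toy sizes, for illustration only (it repeats the relevant tensor ops rather than calling the private helper):

```python
# Illustration only (toy sizes, not a call into the private helper above):
# the same tensor ops that build the sliding-window causal mask, for a prompt
# of length 6, no cached tokens, and sliding_window=3. A position may attend
# to itself and to at most sliding_window - 1 earlier positions.
import torch

sequence_length, sliding_window = 6, 3
min_dtype = torch.finfo(torch.float32).min
cache_position = torch.arange(sequence_length)

causal_mask = torch.full((sequence_length, sequence_length), fill_value=min_dtype)
diagonal_attend_mask = torch.arange(sequence_length) > cache_position.reshape(-1, 1)
sliding_attend_mask = torch.arange(sequence_length) <= (cache_position.reshape(-1, 1) - sliding_window)
causal_mask *= diagonal_attend_mask.bitwise_or_(sliding_attend_mask)

print((causal_mask == 0).int())  # 1 marks an attendable position
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]], dtype=torch.int32)
```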
+ + +class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3ForCausalLM + + >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen3 Model transformer with a sequence classification head on top (linear layer). + + [`Qwen3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + QWEN3_START_DOCSTRING, +) +class Qwen3ForSequenceClassification(Qwen3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen3 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + QWEN3_START_DOCSTRING, +) +class Qwen3ForTokenClassification(Qwen3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen3Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Qwen3 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + QWEN3_START_DOCSTRING, +) +class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = Qwen3Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Qwen3ForCausalLM", + "Qwen3ForQuestionAnswering", + "Qwen3Model", + "Qwen3PreTrainedModel", + "Qwen3ForSequenceClassification", + "Qwen3ForTokenClassification", +] diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py new file mode 100644 index 00000000000..58c02036b6e --- /dev/null +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3 model.""" + +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint + +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + logging, +) +from ..gemma.modeling_gemma import GemmaMLP +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaRMSNorm, + apply_rotary_pos_emb, + eager_attention_forward, +) +from ..mistral.modeling_mistral import MistralModel +from .configuration_qwen3 import Qwen3Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B" + + +class Qwen3RMSNorm(LlamaRMSNorm): + pass + + +class Qwen3MLP(GemmaMLP): + pass + + +class Qwen3Attention(LlamaAttention): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__(config, layer_idx) + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.sliding_window = config.sliding_window + if not ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + self.sliding_window = None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen3MLP(config) + if ( + config.sliding_window and config._attn_implementation != "flash_attention_2" + ): # diff with Llama is this warning + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + +class Qwen3Model(MistralModel): # mistral model creates sliding window + pass + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class Qwen3ForCausalLM(LlamaForCausalLM): + def forward( + self, + **super_kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3ForCausalLM + + >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + return super().forward(**super_kwargs) + + +class Qwen3ForSequenceClassification(LlamaForSequenceClassification): + pass + + +class Qwen3ForTokenClassification(LlamaForTokenClassification): + pass + + +class Qwen3ForQuestionAnswering(LlamaForQuestionAnswering): + pass + + +__all__ = [ + "Qwen3ForCausalLM", + "Qwen3ForQuestionAnswering", + "Qwen3Model", + "Qwen3PreTrainedModel", # noqa: F822 + "Qwen3ForSequenceClassification", + "Qwen3ForTokenClassification", +] diff --git a/src/transformers/models/qwen3_moe/__init__.py b/src/transformers/models/qwen3_moe/__init__.py new file mode 100644 index 00000000000..715effd8b7f --- /dev/null +++ b/src/transformers/models/qwen3_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
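The package `__init__` below registers the new configuration and modeling modules through transformers' lazy-import machinery: nothing is imported until an attribute is first requested. A small sketch of how that resolution behaves at runtime, assuming this patch is installed:

```python
# Minimal sketch of the lazy-import behaviour wired up below (assumes this patch
# is installed). The package object is a `_LazyModule` proxy; the concrete
# classes are only imported when first accessed.
import importlib

qwen3_moe = importlib.import_module("transformers.models.qwen3_moe")
print(type(qwen3_moe).__name__)        # `_LazyModule` rather than a plain module

config_cls = qwen3_moe.Qwen3MoeConfig  # attribute access triggers the real import
print(config_cls.model_type)           # "qwen3_moe"
```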
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen3_moe import * + from .modeling_qwen3_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/qwen3_moe/configuration_qwen3_moe.py b/src/transformers/models/qwen3_moe/configuration_qwen3_moe.py new file mode 100644 index 00000000000..0997bbd7761 --- /dev/null +++ b/src/transformers/models/qwen3_moe/configuration_qwen3_moe.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3MoE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen3MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a + Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of [Qwen/Qwen3-MoE-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen3MoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen3MoeModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 6144): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. 
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+        max_window_layers (`int`, *optional*, defaults to 28):
+            The number of layers using full attention. The first `max_window_layers` layers use full attention, while
+            any layer after them uses SWA (Sliding Window Attention) when `use_sliding_window` is enabled.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
+        moe_intermediate_size (`int`, *optional*, defaults to 768):
+            Intermediate size of the routed expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 8):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 128):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the top-k probabilities.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
+            Indicates which layers use a Qwen3MoeMLP rather than a Qwen3MoeSparseMoeBlock. The list contains layer
+            indices, from 0 to num_layers-1 if we have num_layers layers. If `mlp_only_layers` is empty,
+            `decoder_sparse_step` is used to determine the sparsity.
+ + ```python + >>> from transformers import Qwen3MoeModel, Qwen3MoeConfig + + >>> # Initializing a Qwen3MoE style configuration + >>> configuration = Qwen3MoeConfig() + + >>> # Initializing a model from the Qwen3-15B-A2B" style configuration + >>> model = Qwen3MoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen3Moe` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=6144, + num_hidden_layers=24, + num_attention_heads=32, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=768, + num_experts_per_tok=8, + num_experts=128, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. 
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Qwen3MoeConfig"] diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py new file mode 100644 index 00000000000..117356dffd4 --- /dev/null +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -0,0 +1,1412 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_moe/modular_qwen3_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
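The modeling file that follows includes a sparse MoE block (`Qwen3MoeSparseMoeBlock`, defined further down) whose routing can be hard to read in place: router logits are softmaxed, the top-k experts are selected per token, optionally renormalized, and tokens are dispatched via a one-hot expert mask. A rough standalone sketch of that routing with toy sizes (names and sizes here are illustrative only, not part of the generated file):

```python
# Toy sketch of top-k softmax routing as done in Qwen3MoeSparseMoeBlock below.
import torch
import torch.nn.functional as F

num_experts, top_k, hidden_dim = 4, 2, 8
hidden_states = torch.randn(5, hidden_dim)                  # 5 tokens
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)

router_logits = gate(hidden_states)                         # (tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # what norm_topk_prob=True does

# expert_mask[e, k, t] == 1 when expert e is the k-th choice of token t,
# mirroring the one_hot/permute trick used in the sparse MoE block.
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])
    print(f"expert {expert_idx} handles tokens {top_x.tolist()} "
          f"with weights {routing_weights[top_x, idx].tolist()}")
```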
+ +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from .configuration_qwen3_moe import Qwen3MoeConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen3-MoE-15B-A2B" +_CONFIG_FOR_DOC = "Qwen3MoeConfig" + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen3MoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3MoeConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.sliding_window = config.sliding_window + if not ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + self.sliding_window = None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3MoeMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+
+class Qwen3MoeRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Qwen3MoeRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Qwen3MoeDecoderLayer(nn.Module):
+    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = Qwen3MoeAttention(config, layer_idx)
+
+        if (layer_idx not in config.mlp_only_layers) and (
+            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3MoeSparseMoeBlock(config)
+        else:
+            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
+
+        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
+                and should not be returned during inference.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = None + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +class Qwen3MoeRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen3MoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # 
Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +QWEN3_MOE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen3MoeConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen3Moe Model outputting raw hidden-states without any specific head on top.", + QWEN3_MOE_START_DOCSTRING, +) +class Qwen3MoePreTrainedModel(PreTrainedModel): + config_class = Qwen3MoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3MoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN3_MOE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen3Moe Model outputting raw hidden-states without any specific head on top.", + QWEN3_MOE_START_DOCSTRING, +) +class Qwen3MoeModel(Qwen3MoePreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen3MoeDecoderLayer`] + + Args: + config: Qwen3MoeConfig + """ + + def __init__(self, config: Qwen3MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen3Moe. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
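When `output_router_logits=True`, the `MoeModelOutputWithPast` returned above carries one router-logit tensor per decoder layer, each of shape `(batch_size * sequence_length, num_experts)`. A small sketch with a randomly initialized configuration (the sizes below are arbitrary and chosen only to keep it cheap):

```python
import torch
from transformers import Qwen3MoeConfig, Qwen3MoeModel

config = Qwen3MoeConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128, moe_intermediate_size=32,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
    num_experts=4, num_experts_per_tok=2,
)
model = Qwen3MoeModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 6))
with torch.no_grad():
    out = model(input_ids, output_router_logits=True)

for layer_logits in out.router_logits:
    print(layer_logits.shape)  # torch.Size([6, 4]) -> (batch * seq_len, num_experts)
```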
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen3MoeConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ config (`Qwen3MoeConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
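A quick numeric check of the behaviour described in the docstring above: for a perfectly balanced router the auxiliary loss evaluates to roughly `top_k`, and it grows as routing concentrates on a few experts. The import path assumes the module layout added by this patch.

```python
import torch
from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func

tokens, num_experts, top_k = 1000, 4, 2

# Uniform router: every expert gets probability 1 / num_experts -> balanced minimum (~top_k).
uniform_logits = torch.zeros(tokens, num_experts)
print(load_balancing_loss_func((uniform_logits,), num_experts=num_experts, top_k=top_k))  # tensor(2.)

# Skewed router: every token strongly prefers expert 0 -> the loss is noticeably larger.
skewed_logits = torch.full((tokens, num_experts), -10.0)
skewed_logits[:, 0] = 10.0
print(load_balancing_loss_func((skewed_logits,), num_experts=num_experts, top_k=top_k))  # ~tensor(4.)
```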
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3MoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: 
Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM + + >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +@add_start_docstrings( + """ + The Qwen3Moe Model transformer with a sequence classification head on top (linear layer). + + [`Qwen3MoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
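As the causal LM forward above shows, passing `output_router_logits=True` together with `labels` folds `router_aux_loss_coef * aux_loss` into the returned `loss`, so a single `backward()` trains both the language-modeling and the load-balancing objectives. A minimal sketch with a tiny randomly initialized model (sizes are arbitrary):

```python
import torch
from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM

config = Qwen3MoeConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128, moe_intermediate_size=32,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
    num_experts=4, num_experts_per_tok=2, router_aux_loss_coef=0.001,
)
model = Qwen3MoeForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 8))
out = model(input_ids, labels=input_ids, output_router_logits=True)

print(out.loss, out.aux_loss)  # loss already includes router_aux_loss_coef * aux_loss
out.loss.backward()
```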
+ """, + QWEN3_MOE_START_DOCSTRING, +) +class Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen3MoeModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen3Moe Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + QWEN3_MOE_START_DOCSTRING, +) +class Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen3MoeModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
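The pooling rule used by the sequence-classification head above (take the rightmost token that is not the padding token) can be illustrated on a toy batch. This is a standalone re-implementation of those two lines, covering both right- and left-padded rows:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [11, 12, 13,  0,  0],  # right padding -> last real token at index 2
    [ 0,  0, 21, 22, 23],  # left padding  -> last real token at index 4
])

non_pad_mask = (input_ids != pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])
```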
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Qwen3Moe Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + QWEN3_MOE_START_DOCSTRING, +) +class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = Qwen3MoeModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
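The question-answering head returns per-token `start_logits` and `end_logits`; one common (greedy) way to turn them into an answer string is sketched here. Real pipelines score start/end pairs jointly and handle unanswerable cases, so treat this only as an illustration; `decode_span` is a hypothetical helper, not part of this patch.

```python
import torch

def decode_span(start_logits, end_logits, input_ids, tokenizer):
    # Greedy span decoding: pick the argmax start and end independently.
    start = int(start_logits.argmax(-1)[0])
    end = int(end_logits.argmax(-1)[0])
    if end < start:
        return ""  # fall back to an empty answer when the greedy span is inverted
    return tokenizer.decode(input_ids[0, start : end + 1], skip_special_tokens=True)
```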
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Qwen3MoeForCausalLM", + "Qwen3MoeForQuestionAnswering", + "Qwen3MoeModel", + "Qwen3MoePreTrainedModel", + "Qwen3MoeForSequenceClassification", + "Qwen3MoeForTokenClassification", +] diff --git a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py new file mode 100644 index 00000000000..31969dc8857 --- /dev/null +++ b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py @@ -0,0 +1,376 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3 model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import MoeCausalLMOutputWithPast +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + logging, +) +from ..llama.modeling_llama import ( + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaRMSNorm, +) +from ..mixtral.modeling_mixtral import ( + MixtralForCausalLM, + MixtralModel, + load_balancing_loss_func, +) +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer +from ..qwen3.modeling_qwen3 import Qwen3Attention +from .configuration_qwen3_moe import Qwen3MoeConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen3-MoE-15B-A2B" + + +class Qwen3MoeAttention(Qwen3Attention): # This is the main diff with qwen2Moe! 
+ pass + + +class Qwen3MoeMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+
+class Qwen3MoeRMSNorm(LlamaRMSNorm):
+    pass
+
+
+class Qwen3MoeDecoderLayer(Qwen2MoeDecoderLayer, nn.Module):
+    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
+        nn.Module().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = Qwen3MoeAttention(config, layer_idx)
+
+        if (layer_idx not in config.mlp_only_layers) and (
+            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3MoeSparseMoeBlock(config)
+        else:
+            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
+
+        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
+                and should not be returned during inference.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = None + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +class Qwen3MoeModel(MixtralModel): + def __init__(self, config: Qwen3MoeConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [Qwen3MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class Qwen3MoeForCausalLM(MixtralForCausalLM): + def __init__(self, config): + super().__init__(config) + self.model = Qwen3MoeModel(config) + self.num_experts = config.num_experts + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). 
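The slicing behaviour of `logits_to_keep` described above can be shown on a stand-in hidden-state tensor; the real forward applies the same `slice_indices` right before the LM head.

```python
import torch

hidden_states = torch.randn(1, 16, 64)  # (batch, seq_len, hidden_size), stand-in tensor

logits_to_keep = 1  # int: keep only the last position
slice_indices = slice(-logits_to_keep, None)
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([1, 1, 64])

keep = torch.tensor([3, 7, 15])  # tensor: keep explicit positions (packed format)
print(hidden_states[:, keep, :].shape)  # torch.Size([1, 3, 64])
```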
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM + + >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class Qwen3MoeForSequenceClassification(LlamaForSequenceClassification): + pass + + +class Qwen3MoeForTokenClassification(LlamaForTokenClassification): + pass + + +class Qwen3MoeForQuestionAnswering(LlamaForQuestionAnswering): + pass + + +__all__ = [ + "Qwen3MoeForCausalLM", + "Qwen3MoeForQuestionAnswering", + "Qwen3MoeModel", + "Qwen3MoePreTrainedModel", # noqa: F822 + "Qwen3MoeForSequenceClassification", + "Qwen3MoeForTokenClassification", +] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 63be1917d0e..7d980c739a0 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8222,6 +8222,90 @@ class Qwen2VLPreTrainedModel(metaclass=DummyObject): 
requires_backends(self, ["torch"]) +class Qwen3ForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3ForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3ForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3ForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3MoeForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3MoeForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3MoeForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3MoeForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3MoeModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Qwen3MoePreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class RagModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 45fa3d9ca68..6fafc3c8090 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -158,6 +158,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ "plbart", "qwen2", "qwen2_moe", + "qwen3", + "qwen3_moe", "resnet", "roberta", "segformer", diff --git a/tests/models/qwen3/__init__.py b/tests/models/qwen3/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py new file mode 100644 index 00000000000..8490d6e2522 --- /dev/null +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -0,0 +1,1039 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
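The `dummy_pt_objects.py` entries added above are what `transformers` exposes when PyTorch is not installed: importing the class still works, but instantiating it raises an error pointing at the missing backend. A small sketch of that behaviour (a no-op in an environment where torch is present):

```python
from transformers import Qwen3MoeForCausalLM  # resolves to the dummy class if torch is missing
from transformers.utils import is_torch_available

if not is_torch_available():
    try:
        Qwen3MoeForCausalLM()  # the dummy __init__ calls requires_backends(self, ["torch"])
    except ImportError as err:
        print(err)  # explains that PyTorch needs to be installed
```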
+"""Testing suite for the PyTorch Qwen3 model.""" + +import gc +import unittest + +import pytest +from packaging import version + +from transformers import AutoTokenizer, Qwen3Config, is_torch_available, set_seed +from transformers.generation.configuration_utils import GenerationConfig +from transformers.testing_utils import ( + backend_empty_cache, + require_bitsandbytes, + require_flash_attn, + require_torch, + require_torch_gpu, + require_torch_sdpa, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + Qwen3ForCausalLM, + Qwen3ForQuestionAnswering, + Qwen3ForSequenceClassification, + Qwen3ForTokenClassification, + Qwen3Model, + ) + + +class Qwen3ModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + max_window_layers=3, + use_sliding_window=True, + sliding_window=50, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + bos_token_id=1, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.max_window_layers = max_window_layers + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.scope = scope + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], 
self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return Qwen3Config( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + max_window_layers=self.max_window_layers, + use_sliding_window=self.use_sliding_window, + sliding_window=self.sliding_window, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + ) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen3 + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = Qwen3Model(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Qwen3 + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = Qwen3Model(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Qwen3 + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = Qwen3ForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Qwen3 + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = Qwen3ForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model( + input_ids, + 
attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen3 +class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + Qwen3Model, + Qwen3ForCausalLM, + Qwen3ForSequenceClassification, + Qwen3ForTokenClassification, + Qwen3ForQuestionAnswering, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "feature-extraction": Qwen3Model, + "text-classification": Qwen3ForSequenceClassification, + "token-classification": Qwen3ForTokenClassification, + "text-generation": Qwen3ForCausalLM, + "zero-shot": Qwen3ForSequenceClassification, + "question-answering": Qwen3ForQuestionAnswering, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + + # TODO (ydshieh): Check this. 
See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + def setUp(self): + self.model_tester = Qwen3ModelTester(self) + self.config_tester = ConfigTester(self, config_class=Qwen3Config, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_torch_fx_output_loss(self): + super().test_torch_fx_output_loss() + + def test_Qwen3_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + print(config) + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = Qwen3ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_Qwen3_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = Qwen3ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_Qwen3_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + model = Qwen3ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen3,llama->Qwen3 + def test_Qwen3_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = 
input_ids.ne(1).to(torch_device)
+        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
+        model = Qwen3ForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
+        )
+
+    @unittest.skip(reason="Qwen3 buffers include complex numbers, which breaks this test")
+    def test_save_load_fast_init_from_base(self):
+        pass
+
+    @unittest.skip(reason="Qwen3 uses GQA on all models so the KV cache is a non standard format")
+    def test_past_key_values_format(self):
+        pass
+
+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
+        self.skipTest(reason="Qwen3 flash attention does not support right padding")
+
+
+@require_torch
+class Qwen3IntegrationTest(unittest.TestCase):
+    @slow
+    def test_model_600m_logits(self):
+        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
+        model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto")
+        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+        with torch.no_grad():
+            out = model(input_ids).logits.float().cpu()
+        # Expected mean on dim = -1
+        EXPECTED_MEAN = torch.tensor([[-1.4577, 1.3261, 3.8498, 3.4229, 2.9009, 1.8813, 2.1530, 2.1431]])
+        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
+        # slicing logits[0, 0, 0:30]
+        EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500]) # fmt: skip
+        torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
+
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()
+
+    @slow
+    def test_model_600m_generation(self):
+        EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% plain, unflavoured, and unadulterated. 
It is""" + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False) + model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto") + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @require_bitsandbytes + @slow + @require_flash_attn + @pytest.mark.flash_attn_test + def test_model_600m_long_prompt(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = Qwen3ForCausalLM.from_pretrained( + "Qwen/Qwen3-0.6B-Base", + device_map="auto", + load_in_4bit=True, + attn_implementation="flash_attention_2", + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + @require_torch_sdpa + def test_model_600m_long_prompt_sdpa(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", attn_implementation="sdpa") + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + + backend_empty_cache(torch_device) + gc.collect() + + EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% plain, unflavoured, and unadulterated. It is" + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False) + + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + def test_speculative_generation(self): + EXPECTED_TEXT_COMPLETION = ( + "My favourite condiment is 100% peanut butter. 
I love it so much that I can't help but use it" + ) + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False) + model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", torch_dtype=torch.float16) + assistant_model = Qwen3ForCausalLM.from_pretrained( + "Qwen/Qwen3-0.6B-Base", device_map="auto", torch_dtype=torch.float16 + ) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + set_seed(0) + generated_ids = model.generate( + input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + convert_and_export_with_cache, + ) + + qwen_model = "Qwen/Qwen3-0.6B-Base" + + tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="", padding_side="right") + EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"] + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = Qwen3ForCausalLM.from_pretrained( + qwen_model, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + }, + ), + ) + + prompt = ["My favourite condiment is "] + prompt_tokens = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + export + exported_program = convert_and_export_with_cache(model) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + + @require_flash_attn + @slow + def test_600m_generation(self): + model_id = "Qwen/Qwen3-0.6B-Base" + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = Qwen3ForCausalLM.from_pretrained( + model_id, use_sliding_window=True, max_window_layers=28, sliding_window=2048 + ).to(torch_device) + # we need a long text to test sliding window + # fmt: off + LONG_TEXT = """The Warring States period in Chinese history (c. 475 – 221 BC) comprises the final centuries of the Zhou dynasty (c. 1046 – 256 BC), which were characterized by warfare, bureaucratic and military reform, and political consolidation. 
It followed the Spring and Autumn period and concluded with the wars of conquest that saw the state of Qin annex each of the other contender states by 221 BC and found the Qin dynasty, the first imperial dynastic state in East Asian history. + + +While scholars have identified several different dates as marking the beginning of the Warring States period, Sima Qian's choice of 475 BC is the most often cited. The era largely corresponds to the second half of the Eastern Zhou period, where the king of Zhou formally ruled as Chinese sovereign, but had lost political power and functioned in practice as a figurehead. This dynamic served as the backdrop for the machinations of the eponymous Warring States. The label "Warring States period" derives from the Record of the Warring States, a work of history compiled during the early Han dynasty (202 BC – 220 AD). + + +Geography + +The political geography of the era was dominated by the Seven Warring States, namely: + + +Besides these seven major states other smaller states survived into the period. They include: + + +Periodisation + +The eastward flight of the Zhou court in 771 BC marks the start of the Spring and Autumn period. No one single incident or starting point inaugurated the Warring States era. The political situation of the period represented a culmination of historical trends of conquest and annexation which also characterised the Spring and Autumn period. As a result, there is some controversy as to the beginning of the era. Proposed starting points include: + + +History + +Background and formation + +The Eastern Zhou dynasty began its fall around 5th century BC. As their influence waned, they had to rely on armies in allied states rather than their own military force. Hundreds of smaller polities coalesced into seven major states which included: Chu, Han, Qin, Wei, Yan, Qi and Zhao. However, there eventually was a shift in alliances because each state's ruler wanted independence. This caused hundreds of wars between 535 and 286 BC. The victorious state would have overall rule and control in China. + + +The system of feudal states created by the Western Zhou dynasty underwent enormous changes after 771 BC with the flight of the Zhou court to modern-day Luoyang and the diminution of its relevance and power. The Spring and Autumn period led to a few states gaining power at the expense of many others, the latter no longer able to depend on central authority for legitimacy or protection. During the Warring States period, many rulers claimed the Mandate of Heaven to justify their conquest of other states and spread their influence. + + +The struggle for hegemony eventually created a state system dominated by several large states, such as Jin, Chu, Qin, Yan, and Qi, while the smaller states of the Central Plain tended to be their satellites and tributaries. Other major states also existed, such as Wu and Yue in the southeast. The last decades of the Spring and Autumn era were marked by increased stability, as the result of peace negotiations between Jin and Chu which established their respective spheres of influence. This situation ended with the partition of Jin, whereby the state was divided between the houses of Han, Zhao and Wei, leading to the seven major warring states. 
+ + +Partition of Jin (453–403 BC) + +The rulers of Jin had steadily lost political powers since the middle of the 6th century BC to their nominally subordinate nobles and military commanders, a situation arising from the traditions of the Jin which forbade the enfeoffment of relatives of the ducal house. This allowed other clans to gain fiefs and military authority, and decades of internecine struggle led to the establishment of four major families, the Han, Zhao, Wei and Zhi. + + +The Battle of Jinyang saw the allied Han, Zhao and Wei destroy the Zhi family (453 BC) and their lands were distributed among them. With this, they became the de facto rulers of most of Jin's territory, though this situation would not be officially recognised until half a century later. The Jin division created a political vacuum that enabled during the first 50 years expansion of Chu and Yue northward and Qi southward. Qin increased its control of the local tribes and began its expansion southwest to Sichuan. + + +Early Warring States + +The three Jins recognized (403–364 BC) + +In 403 BC, the court of King Weilie of Zhou officially recognized Zhao, Wei and Han as immediate vassals, thereby raising them to the same rank as the other warring states. + + +From before 405 until 383 BC the three Jins were united under the leadership of Wei and expanded in all directions. The most important figure was Marquess Wen of Wei (445–396 BC). In 408–406 BC he conquered the State of Zhongshan to the northeast on the other side of Zhao. At the same time he pushed west across the Yellow River to the Luo River taking the area of Xihe (literally 'west of the river'). + + +The growing power of Wei caused Zhao to back away from the alliance. In 383 BC it moved its capital to Handan and attacked the small state of Wey. Wey appealed to Wei which attacked Zhao on the western side. Being in danger, Zhao called in Chu. As usual, Chu used this as a pretext to annex territory to its north, but the diversion allowed Zhao to occupy a part of Wei. This conflict marked the end of the power of the united Jins and the beginning a period of shifting alliances and wars on several fronts. + + +In 376 BC, the states of Han, Wei and Zhao deposed Duke Jing of Jin and divided the last remaining Jin territory between themselves, which marked the final end of the Jin state. + + +In 370 BC, Marquess Wu of Wei died without naming a successor, which led to a war of succession. After three years of civil war, Zhao from the north and Han from the south invaded Wei. On the verge of conquering Wei, the leaders of Zhao and Han fell into disagreement about what to do with Wei, and both armies abruptly retreated. As a result, King Hui of Wei (still a Marquess at the time) was able to ascend the throne of Wei. + + +Zhao extended from the Shanxi plateau across the plain to the borders of Qi. Wei reached east to Qi, Lu, and Song. To the south, the weaker state of Han held the east–west part of the Yellow River valley, surrounded the Zhou royal domain at Luoyang and held an area north of Luoyang called Shangdang. + + +Qi resurgence under Tian (379–340 BC) + +Duke Kang of Qi died in 379 BC with no heir from the house of Jiang, which had ruled Qi since the state's founding. The throne instead passed to the future King Wei, from the house of Tian. The Tian had been very influential at court towards the end of Jiang rule, and now openly assumed power. + + +The new ruler set about reclaiming territories that had been lost to other states. 
He launched a successful campaign against Zhao, Wey and Wei, once again extending Qi territory to the Great Wall. Sima Qian writes that the other states were so awestruck that nobody dared attack Qi for more than 20 years. The demonstrated military prowess also had a calming effect on Qi's own population, which experienced great domestic tranquility during Wei's reign. + + +By the end of King Wei's reign, Qi had become the strongest of the states and proclaimed itself "king"; establishing independence from the Zhou dynasty (see below). + + +Wars of Wei + +King Hui of Wei (370–319 BC) set about restoring the state. In 362–359 BC he exchanged territories with Han and Zhao in order to make the boundaries of the three states more rational. + + +In 364 BC, Wei was defeated by Qin at the Battle of Shimen and was only saved by the intervention of Zhao. Qin won another victory in 362 BC. In 361 BC the Wei capital was moved east to Daliang to be out of the reach of Qin. + + +In 354 BC, King Hui of Wei started a large-scale attack on Zhao. By 353 BC, Zhao was losing badly and its capital, Handan, was under siege. The state of Qi intervened. The famous Qi strategist, Sun Bin the great-great-great-grandson of Sun Tzu, the author of the Art of War, proposed to attack the Wei capital while the Wei army was tied up besieging Zhao. The strategy was a success; the Wei army hastily moved south to protect its capital, was caught on the road and decisively defeated at the Battle of Guiling. The battle is remembered in the second of the Thirty-Six Stratagems, "besiege Wei, save Zhao"—meaning to attack a vulnerable spot to relieve pressure at another point. + + +Domestically, King Hui patronized philosophy and the arts, and is perhaps best remembered for hosting the Confucian philosopher Mencius at his court; their conversations form the first two chapters of the book which bears Meng Zi's name. + + +Dukes become kings + +Qi and Wei became kingdoms (344 BC) + +The title of king (wang, 王) was held by figurehead rulers of the Zhou dynasty, while the rulers of most states held the title of duke (gong, 公) or marquess (hou, 侯). A major exception was Chu, whose rulers were called kings since King Wu of Chu started using the title c. 703 BC. + + +In 344 BC the rulers of Qi and Wei mutually recognized each other as kings: King Wei of Qi and King Hui of Wei, in effect declaring their independence from the Zhou court. This marked a major turning point: unlike those in the Spring and Autumn period, the new generation of rulers ascending the thrones in the Warring States period would not entertain even the pretence of being vassals of the Zhou dynasty, instead proclaiming themselves fully independent kingdoms. + + +Shang Yang reforms Qin (356–338 BC) + +During the early Warring States period Qin generally avoided conflicts with the other states. This changed during the reign of Duke Xiao, when prime minister Shang Yang made centralizing and authoritarian reforms in accordance with his Legalist philosophy between the years 356 and 338 BC. + + +Shang introduced land reforms, privatized land, rewarded farmers who exceeded harvest quotas, enslaved farmers who failed to meet quotas, and used enslaved subjects as rewards for those who met government policies. As manpower was short in Qin relative to the other states at the time, Shang enacted policies to increase its manpower. 
As Qin peasants were recruited into the military, he encouraged active immigration of peasants from other states into Qin as a replacement workforce; this policy simultaneously increased the manpower of Qin and weakened the manpower of Qin's rivals. + + +Shang made laws forcing citizens to marry at a young age and passed tax laws to encourage raising multiple children. He also enacted policies to free convicts who worked in opening wastelands for agriculture. Shang abolished primogeniture and created a double tax on households that had more than one son living in the household, to break up large clans into nuclear families. Shang also moved the capital to reduce the influence of nobles on the administration. + + +The rise of Qin was recognized by the royal court, and in 343 BC the king conferred the title of Count (伯 Bó) on Duke Xiao. As was customary, a conference was hosted which the feudal lords attended, and during which the Son of Heaven bestowed the title. + + +After the reforms Qin became much more aggressive. In 340 Qin took land from Wèi after it had been defeated by Qi. In 316 Qin conquered Shu and Ba in Sichuan to the southwest. Development of this area took a long time but slowly added greatly to Qin's wealth and power. + + +Qin defeats Wei (341–340 BC) + +In 341 BC, Wei attacked Han. Qi allowed Han to be nearly defeated and then intervened. The generals from the Battle of Guiling met again (Sun Bin and Tian Ji versus Pang Juan), using the same tactic, attacking Wei's capital. Sun Bin feigned a retreat and then turned on the overconfident Wei troops and decisively defeated them at the Battle of Maling. After the battle all three of the Jin successor states appeared before King Xuan of Qi, pledging their loyalty. + + +In the following year Qin attacked the weakened Wei. Wei was devastatingly defeated and ceded a large part of its territory in return for truce. With Wei severely weakened, Qi and Qin became the dominant states in China. + + +Wei came to rely on Qi for protection, with King Hui of Wei meeting King Xuan of Qi on two occasions. After Hui's death, his successor King Xiang also established a good relationship with his Qi counterpart, with both promising to recognize the other as "king". + + +Chu conquers Yue (334 BC) + +Early in the Warring States period, Chu was one of the strongest states in China. The state rose to a new level of power around 389 BC when King Dao of Chu (楚悼王) named the famous reformer Wu Qi as his chancellor. + + +Chu rose to its peak in 334 BC, when it conquered Yue to its east on the Pacific coast. The series of events leading up to this began when Yue prepared to attack Qi to its north. The King of Qi sent an emissary who persuaded the King of Yue to attack Chu instead. Yue initiated a large-scale attack at Chu but was defeated by Chu's counter-attack. Chu then proceeded to conquer Yue. + + +Qin, Han and Yan became kingdoms (325–323 BC) + +King Xian of Zhou had attempted to use what little royal prerogative he had left by appointing the dukes Xian (384–362 BC), Xiao (361–338 BC) and Hui (338–311 BC) of Qin as hegemons, thereby in theory making Qin the chief ally of the court. + + +However, in 325 the confidence of Duke Hui grew so great that he proclaimed himself "king" of Qin; adopting the same title as the king of Zhou and thereby effectively proclaiming independence from the Zhou dynasty. King Hui of Qin was guided by his prime minister Zhang Yi, a prominent representative of the School of Diplomacy. 
+ + +He was followed in 323 BC by King Xuanhui of Han and King Yi of Yan, as well as King Cuo of the minor state Zhongshan. In 318 BC even the ruler of Song, a relatively minor state, declared himself king. Uniquely, while King Wuling of Zhao had joined the other kings in declaring himself king, he retracted this order in 318 BC, after Zhao suffered a great defeat at the hands of Qin. + + +Partition of Zhou (314 BC) + +King Kao of Zhou had enfeoffed his younger brother as Duke Huan of Henan. Three generations later, this cadet branch of the royal house began calling themselves "dukes of East Zhou". + + +Upon the ascension of King Nan in 314, East Zhou became an independent state. The king came to reside in what became known as West Zhou. + + +Horizontal and vertical alliances (334–249 BC) + +Towards the end of the Warring States period, the state of Qin became disproportionately powerful compared with the other six states. As a result, the policies of the six states became overwhelmingly oriented towards dealing with the Qin threat, with two opposing schools of thought. One school advocated a 'vertical' or north–south alliance called hezong (合縱) in which the states would ally with each other to repel Qin. The other advocated a 'horizontal' or east–west alliance called lianheng (連橫{), in which a state would ally with Qin to participate in its ascendancy. + + +There were some initial successes in hezong, though mutual suspicions between allied states led to the breakdown of such alliances. Qin repeatedly exploited the horizontal alliance strategy to defeat the states one by one. During this period, many philosophers and tacticians travelled around the states, recommending that the rulers put their respective ideas into use. These "lobbyists", such as Su Qin, who advocated vertical alliances, and Zhang Yi, who advocated horizontal alliances, were famous for their tact and intellect, and were collectively known as the School of Diplomacy, whose Chinese name (縱橫家 'the school of the vertical and horizontal') was derived from the two opposing ideas. + + +Su Qin and the first vertical alliance (334–300 BC) + +Beginning in 334 BC the diplomat Su Qin spent years visiting the courts of Yan, Zhao, Han, Wei, Qi and Chu and persuaded them to form a united front against Qin. In 318 BC all states except Qi launched a joint attack on Qin, which was not successful. + + +King Hui of Qin died in 311 BC, followed by prime minister Zhang Yi one year later. The new monarch, King Wu, reigned only four years before dying without legitimate heirs. Some damaging turbulence ensued throughout 307 BC before a son of King Hui by a concubine (i.e. a younger half-brother of King Wu) could be established as King Zhao, who in stark contrast to his predecessor went on to rule for an unprecedented 53 years. + + +After the failure of the first vertical alliance, Su Qin eventually came to live in Qi, where he was favored by King Xuan and drew the envy of the ministers. An assassination attempt in 300 BC left Su mortally wounded but not dead. Sensing death approaching, he advised the newly crowned King Min have him publicly executed to draw out the assassins. King Min complied with Su's request and killed him, putting an end to the first generation of Vertical alliance thinkers. + + +The first horizontal alliance (300–287 BC) + +King Min of Qi came to be highly influenced by Lord Mengchang, a grandson of the former King Wei of Qi. Lord Mengchang made a westward alliance with the states of Wei and Han. 
In the far west, Qin, which had been weakened by a succession struggle in 307, yielded to the new coalition and appointed Lord Mengchang its chief minister. The alliance between Qin and Qi was sealed by a Qin princess marrying King Min. This horizontal or east–west alliance might have secured peace except that it excluded the State of Zhao. + + +Around 299 BC, the ruler of Zhao became the last of the seven major states to proclaim himself "king". + + +In 298 BC, Zhao offered Qin an alliance and Lord Mengchang was driven out of Qin. The remaining three allies, Qi, Wei and Han, attacked Qin, driving up the Yellow River below Shanxi to the Hangu Pass. After 3 years of fighting they took the pass and forced Qin to return territory to Han and Wei. They next inflicted major defeats on Yan and Chu. During the 5-year administration of Lord Mengchang, Qi was the major power in China. + + +In 294, Lord Mengchang was implicated in a coup d'état and fled to Wei. His alliance system collapsed. +Qi and Qin made a truce and pursued their own interests. Qi moved south against the state of Song whilst the Qin General Bai Qi pushed back eastward against a Han/Wei alliance, gaining victory at the Battle of Yique. + + +In 288, King Zhao of Qin and King Min of Qi took the title di (帝 'emperor'), of the west and east respectively. They swore a covenant and started planning an attack on Zhao. + + +Su Dai and the second vertical alliance + +In 287 BC the strategist Su Dai, younger brother of Su Qin and possibly an agent of Yan, persuaded King Min that the Zhao war would only benefit Qin. King Min agreed and formed a 'vertical' alliance with the other states against Qin. Qin backed off, abandoned the presumptuous title of "Di", and restored territory to Wei and Zhao. In 286 Qi annexed the state of Song. + + +The second horizontal alliance and fall of Qi + +In 285 BC, the success of Qi had frightened the other states. Under the leadership of Lord Mengchang, who was exiled in Wei, Qin, Zhao, Wei and Yan formed an alliance. Yan had normally been a relatively weak ally of Qi and Qi feared little from this quarter. Yan's onslaught under general Yue Yi came as a devastating surprise. Simultaneously, the other allies attacked from the west. Chu declared itself an ally of Qi but contented itself with annexing some territory to its north. Qi's armies were destroyed while the territory of Qi was reduced to the two cities of Ju and Jimo. King Min himself was later captured and executed by his own followers. + + +King Min was succeeded by King Xiang in 283 BC. His general Tian Dan was eventually able to restore much of Qi's territory, but it never regained the influence it had under King Min. + + +Qin and Zhao expansion + +In 278 BC, the Qin general Bai Qi attacked from Qin's new territory in Sichuan to the west of Chu. The capital of Ying was captured and Chu's western lands on the Han River were lost. The effect was to shift Chu significantly to the east. + + +After Chu was defeated in 278, the remaining great powers were Qin in the west and Zhao in the north-center. There was little room for diplomatic maneuver and matters were decided by wars. Zhao had been much strengthened by King Wuling of Zhao (325–299). In 307 he enlarged his cavalry by copying the northern nomads. In 306 he took more land in the northern Shanxi plateau. In 305 he defeated the north-eastern border state of Zhongshan. In 304 he pushed far to the north-west and occupied the east–west section of the Yellow River in the north of the Ordos Loop. 
King Huiwen of Zhao (298–266) chose able servants and expanded against the weakened Qi and Wei. In 296 his general Lian Po defeated two Qin armies. + + +In 269 BC Fan Sui became chief advisor to Qin. He advocated authoritarian reforms, irrevocable expansion and an alliance with distant states to attack nearby states (the twenty-third of the Thirty-Six Stratagems). His maxim "attack not only the territory, but also the people" enunciated a policy of mass slaughter that became increasingly frequent. + + +Qin-Zhao wars (282–257 BC) + +In 265 King Zhaoxiang of Qin made the first move by attacking the weak state of Han which held the Yellow River gateway into Qin. He moved north-east across Wei territory to cut off the Han exclave of Shangdang north of Luoyang and south of Zhao. The Han king agreed to surrender Shangdang, but the local governor refused and presented it to King Xiaocheng of Zhao. Zhao sent out Lian Po who based his armies at Changping and Qin sent out general Wang He. Lian Po was too wise to risk a decisive battle with the Qin army and remained inside his fortifications. Qin could not break through and the armies were locked in stalemate for three years. The Zhao king decided that Lian Po was not aggressive enough and sent out Zhao Kuo who promised a decisive battle. At the same time Qin secretly replaced Wang He with the notoriously violent Bai Qi. When Zhao Kuo left his fortifications, Bai Qi used a Cannae maneuver, falling back in the center and surrounding the Zhao army from the sides. After being surrounded for 46 days, the starving Zhao troops surrendered in September 260 BC. It is said that Bai Qi had all the prisoners killed and that Zhao lost 400,000 men. + + +Qin was too exhausted to follow up its victory. Some time later it sent an army to besiege the Zhao capital but the army was destroyed when it was attacked from the rear. Zhao survived, but there was no longer a state that could resist Qin on its own. The other states could have survived if they remained united against Qin, but they did not. + + +In 257 BC, Qin army failed to besiege Handan and was defeated by the allied force of Zhao, Wei and Chu during the Battle of Handan. + + +End of Zhou dynasty (256–249 BC) + +The forces of King Zhao of Qin defeated King Nan of Zhou and conquered West Zhou in 256 BC, claiming the Nine Cauldrons and thereby symbolically becoming The Son of Heaven. + + +King Zhao's exceptionally long reign ended in 251 BC. His son King Xiaowen, already an old man, died just three days after his coronation and was succeeded by his son King Zhuangxiang of Qin. The new Qin king proceeded to conquer East Zhou, seven years after the fall of West Zhou. Thus the 800-year Zhou dynasty, nominally China's longest-ruling regime, finally came to an end. + + +Sima Qian contradicts himself regarding the ultimate fate of the East Zhou court. Chapter 4 (The Annals of Zhou) concludes with the sentence "thus the sacrifices of Zhou ended", but in the following chapter 5 (The Annals of Qin) we learn that "Qin did not prohibit their sacrifices; the Lord of Zhou was allotted a patch of land in Yangren where he could continue his ancestral sacrifices". + + +Qin unites China (247–221 BC) + +King Zhuangxiang of Qin ruled for only three years. He was succeeded by his son Zheng, who unlike the two elderly kings that preceded him was only 13 years old at his coronation. As an adult, Zheng became a brilliant commander who, in the span of just nine years, unified China. + + +Conquest of Han + +In 230 BC, Qin conquered Han. 
Han, the weakest of the Seven Warring States, was adjacent to the much stronger Qin, and had suffered continuous assaults by Qin in earlier years of the Warring States period. This went on until Emperor Qin Shi Huang sent general Wang Jian to attack Zhao. King An of Han, frightened by the thought that Han would be the next target of the Qin state, immediately sent diplomats to surrender the entire kingdom without a fight, saving the Han populace from the terrible potential consequences of an unsuccessful resistance. + + +Conquest of Wei + +In 225 BC, Qin conquered Wei. The Qin army led a direct invasion into Wei by besieging its capital Daliang but soon realized that the city walls were too tough to break into. They devised a new strategy in which they utilized the power of a local river that was linked to the Yellow River. The river was used to flood the city's walls, causing massive devastation to the city. Upon realizing the situation, King Jia of Wei hurriedly came out of the capital and surrendered it to the Qin army in order to avoid further bloodshed of his people. + + +Conquest of Chu + +In 223 BC, Qin conquered Chu. +The first invasion was however an utter disaster when 200,000 Qin troops, led by the general, Li Xin, were defeated by 500,000 Chu troops in the unfamiliar territory of Huaiyang, modern-day northern Jiangsu and Anhui provinces. Xiang Yan, the Chu commander, had lured Qin by allowing a few initial victories, but then counterattacked and burnt two large Qin camps. + + +In 222 BC, Wang Jian was recalled to lead a second military invasion with 600,000 men against the Chu state. High in morale after their victory in the previous year, the Chu forces were content to sit back and defend against what they expected to be a siege of Chu. However, Wang Jian decided to weaken Chu's resolve and tricked the Chu army by appearing to be idle in his fortifications whilst secretly training his troops to fight in Chu territory. After a year, the Chu defenders decided to disband due to apparent lack of action from the Qin. Wang Jian invaded at that point, with full force, and overran Huaiyang and the remaining Chu forces. Chu lost the initiative and could only sustain local guerrilla-style resistance until it too was fully conquered with the destruction of Shouchun and the death of its last leader, Lord Changping, in 223 BC. At their peak, the combined armies of Chu and Qin are estimated to have ranged from hundreds of thousands to a million soldiers, more than those involved in the campaign of Changping between Qin and Zhao 35 years earlier. + + +Conquest of Zhao and Yan + +In 222 BC, Qin conquered Zhao and Yan. +After the conquest of Zhao, the Qin army turned its attention towards Yan. Realizing the danger and gravity of this situation, Crown Prince Dan of Yan had sent Jing Ke to assassinate King Zheng of Qin, but this failure only helped to fuel the rage and determination of the Qin king, and he increased the number of troops to conquer the Yan state. + + +Conquest of Qi + +In 221 BC, Qin conquered Qi, the final unconquered state. It had not previously contributed or helped other states when Qin was conquering them. As soon as Qin's intention to invade it became clear, Qi swiftly surrendered all its cities, completing the unification of China and ushering in the Qin dynasty. The last Qi king lived out his days in exile in Gong and was not given a posthumous name, therefore he is known to posterity by his personal name Jian. 
+ + +Aftermath + +The Qin king Ying Zheng declared himself as Qin Shi Huangdi, "The first Sovereign Emperor of Qin". + + +In the rule of the Qin state, the union was based solely on military power. The feudal holdings were abolished, and noble families were forced to live in the capital city Xianyang, in order to be supervised. A national road (as well as greater use of canals) allowed for faster and easier deployment and supply of the army. The peasants were given a wider range of land rights, although they were subject to taxation, creating a large amount of revenue to the state. + + +Military theory and practice + +Increasing scale of warfare + +The chariot remained a major factor in Chinese warfare long after it went out of fashion in the Middle East. Near the beginning of the Warring States period there is a shift from chariots to massed infantry, possibly associated with the invention of the crossbow. This had two major effects. First it led the dukes to weaken their chariot-riding nobility so they could get direct access to the peasantry who could be drafted as infantry. This change was associated with the shift from aristocratic to bureaucratic government. Second, it led to a massive increase in the scale of warfare. When the Zhou overthrew the Shang at the Battle of Muye they used 45,000 troops and 300 chariots. For the Warring States period the following figures for the military strengths of various states are reported: + + +For major battles, the following figures are reported: + + +Many scholars think these numbers are exaggerated (records are inadequate, they are much larger than those from similar societies, soldiers were paid by the number of enemies they killed and the Han dynasty had an interest in exaggerating the bloodiness of the age before China was unified). Regardless of exaggeration, it seems clear that warfare had become excessive during this period. The bloodshed and misery of the Warring States period goes a long way in explaining China's traditional and current preference for a united throne. + + +Military developments + +The Warring States period saw the introduction of many innovations to the art of warfare in China, such as the use of iron and of cavalry. + + +Warfare in the Warring States period evolved considerably from the Spring and Autumn period, as most armies made use of infantry and cavalry in battles, and the use of chariots became less widespread. The use of massed infantry made warfare bloodier and reduced the importance of the aristocracy, which in turn made the kings more despotic. From this period onward, as the various states competed with each other by mobilizing their armies to war, nobles in China belonged to the literate class, rather than to the warrior class as had previously been the case. + + +The various states fielded massive armies of infantry, cavalry, and chariots. Complex logistical systems maintained by efficient government bureaucracies were needed to supply, train, and control such large forces. The size of the armies ranged from tens of thousands to several hundred thousand men. +Iron weapons became more widespread and began to replace bronze. Most armour and weapons of this period were made from iron. + + +The first official native Chinese cavalry unit was formed in 307 BC during the military reforms of King Wuling of Zhao, who advocated 'nomadic dress and horse archery'. But the war chariot still retained its prestige and importance, despite the tactical superiority of cavalry. 
+ + +The crossbow was the preferred long-range weapon of this period, due to several reasons. The crossbow could be mass-produced easily, and mass training of crossbowmen was possible. These qualities made it a powerful weapon against the enemy. + + +Infantrymen deployed a variety of weapons, but the most popular was the dagger-axe. The dagger-axe came in various lengths, from 9 to 18 feet; the weapon consisted of a thrusting spear with a slashing blade appended to it. Dagger-axes were an extremely popular weapon in various kingdoms, especially for the Qin, who produced 18-foot-long pike-like weapons. + + +The Qiang battle spear was named as the king 'wang' of all ancient weapons. It had the biggest impact on the battlefield and was quite difficult to master. The second important weapon of that era was the double-edged battle sword Jian. The fighting methods of using the Qiang spear and Jian sword were very different from what we see in movies or re-enactment shows today. Professional warriors of that era used the military concepts of "Master" Sun Tzu and created several successful "Ge Dou" martial schools. + + +Military thought + +The Warring States was a great period for military strategy; of the Seven Military Classics of China, four were written during this period: + + +Culture and society + +The Warring States period was an era of warfare in ancient China, as well as bureaucratic and military reforms and consolidation; the major states, ruling over large territories, quickly sought to consolidate their powers, leading to the final erosion of the Zhou court's prestige. As a sign of this shift, the rulers of all the major states (except for Chu, which had claimed kingly title much earlier) abandoned their former feudal titles for the title of 王, or King, claiming equality with the rulers of the Zhou. + + +At the same time, the constant conflict and need for innovative social and political models led to the development of many philosophical doctrines, later known as the Hundred Schools of Thought. The most notable schools of thought include Mohism (expounded by Mozi), Confucianism (represented by Mencius and Xunzi), Legalism (represented by Shang Yang, Shen Buhai, Shen Dao and Han Fei) and Taoism (represented by Zhuangzi and Lao Tzu). + + +The many states that were competing between each other attempted to display their power not only militarily but in their courts and in state philosophy. Many differing rulers adopted the differing philosophies to their own advantage or that of their kingdom. + + +Mencius attempted to instate Confucianism as a state philosophy, proposing that through the governing of moral principles like benevolence and righteousness, the state would win popular support from one state and those neighboring, eliminating the need of a war altogether. Mencius had attempted to convince King Hui of Liang, although was unsuccessful since the king saw no advantage in the period of wars. + + +Mohism was developed by Mozi (468–376 BC) and it provided a unified moral and political philosophy based on impartiality and benevolence. Mohists had the belief that people change depending on environments around. The same was applied to rulers, which is why one must be cautious of foreign influences. Mozi was very much against warfare, although he was a great tactician in defense. He defended the small state of Song from many attempts of the Chu state. 
+ + +Taoism was advocated by Laozi, and believed that human nature was good and can achieve perfection by returning to its original state. It believed that like a baby, humans are simple and innocent although with development of civilizations it lost its innocence only to be replaced by fraud and greed. Contrarily to other schools, it did not want to gain influence in the offices of states and Laozi even refused to be the minister of the state of Chu. + + +Legalism created by Shang Yang in 338 BC, rejected all notions of religion and practices, and believed a nation should be governed by strict law. Not only were severe punishments applied, but they would be grouped with the families and made mutually responsible for criminal act. It proposed radical reforms, and established a society based on solid ranks. Peasants were encouraged to practice agriculture as occupation, and military performance was rewarded. Laws were also applied to all ranks with no exception; even the king was not above punishment. The philosophy was adapted by the Qin state and it created it into an organized, centralized state with a bureaucracy chosen on the basis of merit. +This period is most famous for the establishment of complex bureaucracies and centralized governments, as well as a clear legal system. The developments in political and military organization were the basis of the power of the Qin state, which conquered the other states and unified them under the Qin dynasty in 221 BC. + + +Nobles, bureaucrats and reformers + +The phenomenon of intensive warfare, based on mass formations of infantry rather than the traditional chariots, was one major trend which led to the creation of strong central bureaucracies in each of the major states. At the same time, the process of secondary feudalism which permeated the Spring and Autumn period, and led to such events as the partition of Jin and the usurpation of Qi by the Tian clan, was eventually reversed by the same process of bureaucratisation. + + +Under the demands of warfare, the states adopted bureaucratic reforms in the Warring States period. Wei adopted these in 445 BC, Zhao in 403 BC, Chu in 390 BC, Han in 355 BC, Qi in 357 BC and Qin in 350 BC. Power was centralised by curbing the landed aristocrats and sinecures and creating a new hierarchy based on meritorious service to the state, which were drawn from the lower rungs of society. Systematic auditing and reporting systems, and fixed salaries for officials were created. + + +The reforms of Shang Yang in Qin, and of Wu Qi in Chu, both centred on increased centralisation, the suppression of the nobility, and a vastly increased scope of government based on Legalist ideals, which were necessary to mobilise the large armies of the period. + + +Sophisticated arithmetic + +A bundle of 21 bamboo slips from the Tsinghua collection dated to 305 BC are the world's earliest example of a two digit decimal multiplication table, indicating that sophisticated commercial arithmetic was already established during this period. + + +Rod numerals were used to represent both negative and positive integers, and rational numbers, a true positional number system, with a blank for zero dating back to the Warring States period. + + +The nine linked-rings puzzle, an advanced puzzle device which requires mathematical analysis to solve, was invented during the period. 
+ + +Literature + +An important literary achievement of the Warring States period is the Zuo Commentary on the Spring and Autumn Annals, which summarizes the preceding Spring and Autumn period. The less famous work Guoyu is thought to be by the same author. + + +Many sayings of Spring and Autumn philosophers, which had previously been circulated orally, were put into writing in the Warring States. These include the Analects and The Art of War. + + +Economic developments + +The Warring States period saw the proliferation of iron working in China, replacing bronze as the dominant type of metal used in warfare. Areas such as Shu (present-day Sichuan) and Yue (present-day Zhejiang) were also brought into the Chinese cultural sphere during this time. Trade also became important, and some merchants had considerable power in politics, the most prominent of which was Lü Buwei, who rose to become Chancellor of Qin and was a key supporter of the eventual Qin Shihuang. + + +At the same time, the increased resources of consolidated, bureaucratic states, coupled with the logistical needs of mass levies and large-scale warfare, led to the proliferation of economic projects such as large-scale waterworks. Major examples of such waterworks include the Dujiangyan Irrigation System, which controlled the Min River in Sichuan and turned the former backwater region into a major Qin logistical base, and the Zhengguo Canal which irrigated large areas of land in the Guanzhong Plain, again increasing Qin's agricultural output. + + +The Guanzi is considered one of the most foundational texts of the developing political economy in the Warring States period. It addresses principles of price regulation in the context of effectively dealing with commodities that are "light" (connoting a commodity which is unimportant, non-essential, or inexpensive) or "heavy" (a commodity which is important, essential, or expensive) and how whether a commodity is "light" or "heavy" is understood in relation to other commodities. + + +In summary:""" + # fmt: on + + input_ids = tokenizer(LONG_TEXT, return_tensors="pt").input_ids.to(torch_device) + generated_ids = model.generate(input_ids, max_new_tokens=50)[:, input_ids.shape[1] :] + + torch.testing.assert_close(generated_ids.cpu(), torch.tensor([[ 279, 467, 19859, 4180, 4168, 572, 264, 4168, 315, 2244, 2297, 304, 5616, 11, 504, 279, 8606, 94350, 1849, 311, 279, 79395, 1584, 11, 504, 279, 8606, 6277, 311, 279, 6277, 11, 504, 279, 8606, 8584, 311, 279, 6955, 11, 323, 504, 279, 8606, 7674, 311, 279, 12752, 13, 576]], dtype=torch.long)) # fmt: skip + self.assertEqual( + tokenizer.decode(generated_ids[0]), + """ the Warring States period was a period of great change in China, from the traditional feudal system to the bureaucratic state, from the traditional military to the military, from the traditional economy to the economic, and from the traditional culture to the cultural. 
The""", + ) + model.config._attn_implementation = "eager" + new_generated_ids = model.generate(input_ids, max_new_tokens=50)[:, input_ids.shape[1] :] + with self.subTest("Eager matches sdpa"): + torch.testing.assert_close(generated_ids, new_generated_ids, rtol=1e-4, atol=1e-4) + + model.config._attn_implementation = "flex_attention" + new_generated_ids = model.generate(input_ids, max_new_tokens=50)[:, input_ids.shape[1] :] + with self.subTest("Eager matches Flex attention"): + torch.testing.assert_close(generated_ids, new_generated_ids, rtol=1e-4, atol=1e-4) + + model.config._attn_implementation = "flash_attention_2" + new_generated_ids = model.generate(input_ids, max_new_tokens=50)[:, input_ids.shape[1] :] + with self.subTest("Eager matches flash attention"): + torch.testing.assert_close(generated_ids, new_generated_ids, rtol=1e-4, atol=1e-4) diff --git a/tests/models/qwen3_moe/__init__.py b/tests/models/qwen3_moe/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/qwen3_moe/test_modeling_qwen3_moe.py b/tests/models/qwen3_moe/test_modeling_qwen3_moe.py new file mode 100644 index 00000000000..061e2ea80be --- /dev/null +++ b/tests/models/qwen3_moe/test_modeling_qwen3_moe.py @@ -0,0 +1,631 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch Qwen3MoE model.""" + +import gc +import unittest + +import pytest + +from transformers import AutoTokenizer, Qwen3MoeConfig, is_torch_available, set_seed +from transformers.testing_utils import ( + backend_empty_cache, + require_bitsandbytes, + require_flash_attn, + require_torch, + require_torch_gpu, + require_torch_sdpa, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + Qwen3MoeForCausalLM, + Qwen3MoeForQuestionAnswering, + Qwen3MoeForSequenceClassification, + Qwen3MoeForTokenClassification, + Qwen3MoeModel, + ) + + +class Qwen3MoeModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + max_window_layers=3, + use_sliding_window=True, + sliding_window=50, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + expert_interval=1, + moe_intermediate_size=12, + num_experts_per_tok=2, + num_experts=8, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + bos_token_id=1, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.max_window_layers = max_window_layers + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.scope = scope + self.expert_interval = expert_interval + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_type_ids = None + if 
self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return Qwen3MoeConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + max_window_layers=self.max_window_layers, + use_sliding_window=self.use_sliding_window, + sliding_window=self.sliding_window, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + expert_interval=self.expert_interval, + moe_intermediate_size=self.moe_intermediate_size, + num_experts_per_tok=self.num_experts_per_tok, + num_experts=self.num_experts, + norm_topk_prob=self.norm_topk_prob, + output_router_logits=self.output_router_logits, + router_aux_loss_coef=self.router_aux_loss_coef, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + ) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen3Moe + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = Qwen3MoeModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Qwen3Moe + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = Qwen3MoeModel(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Qwen3Moe + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = Qwen3MoeForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = 
model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Qwen3Moe + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = Qwen3MoeForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next tokens and extend to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append the new tokens to next_input_ids and next_attention_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen3Moe +class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + Qwen3MoeModel, + Qwen3MoeForCausalLM, + Qwen3MoeForSequenceClassification, + Qwen3MoeForTokenClassification, + Qwen3MoeForQuestionAnswering, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "feature-extraction": Qwen3MoeModel, + "text-classification": Qwen3MoeForSequenceClassification, + "token-classification": Qwen3MoeForTokenClassification, + "text-generation": Qwen3MoeForCausalLM, + "zero-shot": Qwen3MoeForSequenceClassification, + "question-answering": Qwen3MoeForQuestionAnswering, + } + if is_torch_available() + 
else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + + # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + def setUp(self): + self.model_tester = Qwen3MoeModelTester(self) + self.config_tester = ConfigTester(self, config_class=Qwen3MoeConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_torch_fx_output_loss(self): + super().test_torch_fx_output_loss() + + def test_Qwen3Moe_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = Qwen3MoeForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_Qwen3Moe_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = Qwen3MoeForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_Qwen3Moe_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + model = Qwen3MoeForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen3Moe,llama->Qwen3Moe + def 
test_Qwen3Moe_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = Qwen3MoeForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + + @unittest.skip(reason="Qwen3Moe buffers include complex numbers, which breaks this test") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="Qwen3Moe uses GQA on all models so the KV cache is a non standard format") + def test_past_key_values_format(self): + pass + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest(reason="Qwen3Moe flash attention does not support right padding") + + # Ignore copy + def test_load_balancing_loss(self): + r""" + Let's make sure we can actually compute the load balancing loss. + """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.num_experts = 8 + config.expert_interval = 2 + config.output_router_logits = True + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + model = Qwen3MoeForCausalLM(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask) + self.assertEqual(result.router_logits[0].shape, (91, config.num_experts)) + torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2) + + # First, we make sure that adding padding tokens doesn't change the loss + # loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding) + pad_length = 1000 + # Add padding tokens (assume that pad_token_id=1) to input_ids + padding_block = torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(torch_device) + padded_input_ids = torch.cat((padding_block, input_ids), dim=1) # this is to simulate padding to the left + padded_attention_mask = padded_input_ids.ne(1).to(torch_device) + + padded_result = model(padded_input_ids, attention_mask=padded_attention_mask) + torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4) + + # We make sure that the loss when including padding tokens != the loss without padding tokens + # if attention_mask=None --> we don't exclude padding tokens + include_padding_result = model(padded_input_ids, attention_mask=None) + + # This is to mimic torch.testing.assert_not_close + self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + + +@require_torch +class Qwen3MoeIntegrationTest(unittest.TestCase): + @slow + def test_model_15b_a2b_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-15B-A2B-Base", device_map="auto") + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + with torch.no_grad(): + out = model(input_ids).logits.float().cpu() + # Expected mean on dim = -1 + EXPECTED_MEAN = 
torch.tensor([[-1.1184, 1.1356, 9.2112, 8.0254, 5.1663, 7.9287, 8.9245, 10.0671]]) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2) + # slicing logits[0, 0, 0:30] + EXPECTED_SLICE = torch.tensor([7.5938, 2.6094, 4.0312, 4.0938, 2.5156, 2.7812, 2.9688, 1.5547, 1.3984, 2.2344, 3.0156, 3.1562, 1.1953, 3.2500, 1.0938, 8.4375, 9.5625, 9.0625, 7.5625, 7.5625, 7.9062, 7.2188, 7.0312, 6.9375, 8.0625, 1.7266, 0.9141, 3.7969, 5.3438, 3.9844]) # fmt: skip + torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + def test_model_15b_a2b_generation(self): + EXPECTED_TEXT_COMPLETION = ( + """To be or not to be, that is the question. Whether 'tis nobler in the mind to suffer the sl""" + ) + prompt = "To be or not to" + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False) + model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-15B-A2B-Base", device_map="auto") + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @require_bitsandbytes + @slow + @require_flash_attn + @pytest.mark.flash_attn_test + def test_model_15b_a2b_long_prompt(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = Qwen3MoeForCausalLM.from_pretrained( + "Qwen/Qwen3-15B-A2B-Base", + device_map="auto", + load_in_4bit=True, + attn_implementation="flash_attention_2", + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0, assistant_model=assistant_model) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + @require_torch_sdpa + def test_model_15b_a2b_long_prompt_sdpa(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = Qwen3MoeForCausalLM.from_pretrained( + "Qwen/Qwen3-15B-A2B-Base", + device_map="auto", + attn_implementation="sdpa", + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0, assistant_model=assistant_model) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + 
+ backend_empty_cache(torch_device) + gc.collect() + + EXPECTED_TEXT_COMPLETION = ( + """To be or not to be, that is the question. Whether 'tis nobler in the mind to suffer the sl""" + ) + prompt = "To be or not to" + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False) + + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + def test_speculative_generation(self): + EXPECTED_TEXT_COMPLETION = ( + "To be or not to be, that is the question: whether 'tis nobler in the mind to suffer the sl" + ) + prompt = "To be or not to" + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False) + model = Qwen3MoeForCausalLM.from_pretrained( + "Qwen/Qwen3-15B-A2B-Base", device_map="auto", torch_dtype=torch.float16 + ) + assistant_model = Qwen3MoeForCausalLM.from_pretrained( + "Qwen/Qwen3-15B-A2B-Base", device_map="auto", torch_dtype=torch.float16 + ) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # speculative (assisted) generation with sampling + set_seed(0) + generated_ids = model.generate( + input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect() diff --git a/tests/repo_utils/modular/test_conversion_order.py b/tests/repo_utils/modular/test_conversion_order.py index 643b3714a76..7f44684b7b5 100644 --- a/tests/repo_utils/modular/test_conversion_order.py +++ b/tests/repo_utils/modular/test_conversion_order.py @@ -22,6 +22,8 @@ FILES_TO_PARSE = [ os.path.join(MODEL_ROOT, "olmo", "modular_olmo.py"), os.path.join(MODEL_ROOT, "rt_detr", "modular_rt_detr.py"), os.path.join(MODEL_ROOT, "qwen2", "modular_qwen2.py"), + os.path.join(MODEL_ROOT, "qwen3", "modular_qwen3.py"), + os.path.join(MODEL_ROOT, "qwen3_moe", "modular_qwen3_moe.py"), os.path.join(MODEL_ROOT, "llava_next_video", "modular_llava_next_video.py"), os.path.join(MODEL_ROOT, "cohere2", "modular_cohere2.py"), os.path.join(MODEL_ROOT, "modernbert", "modular_modernbert.py"), diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py index 341fc42b9c6..a22abd23881 100644 --- a/utils/check_config_docstrings.py +++ b/utils/check_config_docstrings.py @@ -47,6 +47,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = { "LlamaConfig", "GraniteConfig", "GraniteMoeConfig", + "Qwen3MoeConfig", }