diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d12644ae9d7..12e4224070f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -555,6 +555,8 @@ title: MegatronBERT - local: model_doc/megatron_gpt2 title: MegatronGPT2 + - local: model_doc/minimax + title: MiniMax - local: model_doc/mistral title: Mistral - local: model_doc/mixtral diff --git a/docs/source/en/model_doc/minimax.md b/docs/source/en/model_doc/minimax.md new file mode 100644 index 00000000000..a8c5ee1b236 --- /dev/null +++ b/docs/source/en/model_doc/minimax.md @@ -0,0 +1,189 @@ + + +# MiniMax + +## Overview + +The MiniMax model was proposed in [MiniMax-01: Scaling Foundation Models with Lightning Attention](https://arxiv.org/abs/2501.08313) by MiniMax, Aonian Li, Bangwei Gong, Bo Yang, Boji Shan, Chang Liu, Cheng Zhu, Chunhao Zhang, Congchao Guo, Da Chen, Dong Li, Enwei Jiao, Gengxin Li, Guojun Zhang, Haohai Sun, Houze Dong, Jiadai Zhu, Jiaqi Zhuang, Jiayuan Song, Jin Zhu, Jingtao Han, Jingyang Li, Junbin Xie, Junhao Xu, Junjie Yan, Kaishun Zhang, Kecheng Xiao, Kexi Kang, Le Han, Leyang Wang, Lianfei Yu, Liheng Feng, Lin Zheng, Linbo Chai, Long Xing, Meizhi Ju, Mingyuan Chi, Mozhi Zhang, Peikai Huang, Pengcheng Niu, Pengfei Li, Pengyu Zhao, Qi Yang, Qidi Xu, Qiexiang Wang, Qin Wang, Qiuhui Li, Ruitao Leng, Shengmin Shi, Shuqi Yu, Sichen Li, Songquan Zhu, Tao Huang, Tianrun Liang, Weigao Sun, Weixuan Sun, Weiyu Cheng, Wenkai Li, Xiangjun Song, Xiao Su, Xiaodong Han, Xinjie Zhang, Xinzhu Hou, Xu Min, Xun Zou, Xuyang Shen, Yan Gong, Yingjie Zhu, Yipeng Zhou, Yiran Zhong, Yongyi Hu, Yuanxiang Fan, Yue Yu, Yufeng Yang, Yuhao Li, Yunan Huang, Yunji Li, Yunpeng Huang, Yunzhi Xu, Yuxin Mao, Zehan Li, Zekang Li, Zewei Tao, Zewen Ying, Zhaoyang Cong, Zhen Qin, Zhenhua Fan, Zhihang Yu, Zhuo Jiang, Zijia Wu. + +The abstract from the paper is the following: + +*We introduce MiniMax-01 series, including MiniMax-Text-01 and MiniMax-VL-01, which are comparable to top-tier models while offering superior capabilities in processing longer contexts. The core lies in lightning attention and its efficient scaling. To maximize computational capacity, we integrate it with Mixture of Experts (MoE), creating a model with 32 experts and 456 billion total parameters, of which 45.9 billion are activated for each token. We develop an optimized parallel strategy and highly efficient computation-communication overlap techniques for MoE and lightning attention. This approach enables us to conduct efficient training and inference on models with hundreds of billions of parameters across contexts spanning millions of tokens. The context window of MiniMax-Text-01 can reach up to 1 million tokens during training and extrapolate to 4 million tokens during inference at an affordable cost. Our vision-language model, MiniMax-VL-01 is built through continued training with 512 billion vision-language tokens. Experiments on both standard and in-house benchmarks show that our models match the performance of state-of-the-art models like GPT-4o and Claude-3.5-Sonnet while offering 20-32 times longer context window.* + +### Architectural details + +MiniMax is a powerful language model with 456 billion total parameters, of which 45.9 billion are activated per token. To better unlock the long context capabilities of the model, MiniMax adopts a hybrid architecture that combines Lightning Attention, Softmax Attention and Mixture-of-Experts (MoE).
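+
+This hybrid layout is exposed through the `layer_types` field of [`MiniMaxConfig`], where every decoder layer is tagged as either `"linear_attention"` (lightning attention) or `"full_attention"` (softmax attention). The sketch below is only an illustration: it assumes the config class is exported at the top level like other model configs, and it uses a small, hypothetical 16-layer configuration to mimic the released model's 7:1 pattern (one softmax attention layer after every seven lightning attention layers; see the architecture summary further down) rather than the actual 80-layer checkpoint:
+
+```python
+>>> from transformers import MiniMaxConfig
+
+>>> # one softmax ("full") attention layer after every 7 lightning ("linear") attention layers
+>>> layer_types = ["full_attention" if (i + 1) % 8 == 0 else "linear_attention" for i in range(16)]
+>>> config = MiniMaxConfig(num_hidden_layers=16, layer_types=layer_types)
+
+>>> config.layer_types.count("linear_attention"), config.layer_types.count("full_attention")
+(14, 2)
+```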
Leveraging advanced parallel strategies and innovative compute-communication overlap methods, such as Linear Attention Sequence Parallelism Plus (LASP+), varlen ring attention, and Expert Tensor Parallel (ETP), MiniMax's training context length is extended to 1 million tokens, and it can handle a context of up to 4 million tokens during inference. On various academic benchmarks, MiniMax also delivers top-tier performance. + +The architecture of MiniMax is briefly described as follows: + +- Total Parameters: 456B +- Activated Parameters per Token: 45.9B +- Number of Layers: 80 +- Hybrid Attention: a softmax attention layer follows every 7 lightning attention layers. + - Number of attention heads: 64 + - Attention head dimension: 128 +- Mixture of Experts: + - Number of experts: 32 + - Expert hidden dimension: 9216 + - Top-2 routing strategy +- Positional Encoding: Rotary Position Embedding (RoPE) applied to half of the attention head dimension with a base frequency of 10,000,000 +- Hidden Size: 6144 +- Vocab Size: 200,064 + +For more details refer to the [release blog post](https://www.minimaxi.com/en/news/minimax-01-series-2). + +### License + +`MiniMax` is released under the MINIMAX MODEL LICENSE AGREEMENT. + +## Usage tips + +The pre-trained model can be used as follows: + +```python +>>> from transformers import AutoModelForCausalLM, AutoTokenizer + +>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", device_map="auto") +>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf") + +>>> messages = [ +... {"role": "user", "content": "What is your favourite condiment?"}, +... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}, +... {"role": "user", "content": "Do you have mayonnaise recipes?"} +... ] + +>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda") + +>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True) +>>> tokenizer.batch_decode(generated_ids)[0] +"Mayonnaise can be made as follows: (...)" +``` + +As can be seen, the instruction-tuned model requires a [chat template](../chat_templating) to be applied to make sure the inputs are prepared in the right format. + +## Speeding up MiniMax by using Flash Attention + +The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model. + +First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature. + +```bash +pip install -U flash-attn --no-build-isolation +``` + +Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). Also make sure to load your model in half-precision (e.g.
`torch.float16`) + +To load and run a model using Flash Attention 2, refer to the snippet below: + +```python +>>> import torch +>>> from transformers import AutoModelForCausalLM, AutoTokenizer + +>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto") +>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf") + +>>> prompt = "My favourite condiment is" + +>>> model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda") + +>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True) +>>> tokenizer.batch_decode(generated_ids)[0] +"The expected output" +``` + +### Sliding window Attention + +The current implementation supports the sliding window attention mechanism and memory-efficient cache management; a short sketch of the model's dedicated cache is shown further down the page. +To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`). + +The Flash Attention 2 integration also uses a more memory-efficient cache slicing mechanism. As recommended by the official implementation of the Mistral model, which uses a rolling cache, we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"`, and use the absolute position of the current token to compute the positional embedding. + +## Shrinking down MiniMax using quantization + +As the MiniMax model has 456 billion parameters, it would require about 912 GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter), about 228 GB of RAM is required. + +Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization.md) for alternative quantization methods): + +```python +>>> import torch +>>> from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + +>>> # specify how to quantize the model +>>> quantization_config = BitsAndBytesConfig( +... load_in_4bit=True, +... bnb_4bit_quant_type="nf4", +... bnb_4bit_compute_dtype=torch.float16, +... ) + +>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", quantization_config=quantization_config, device_map="auto") +>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf") + +>>> messages = [ +... {"role": "user", "content": "What is your favourite condiment?"}, +... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}, +... {"role": "user", "content": "Do you have mayonnaise recipes?"} +... ] + +>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda") + +>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True) +>>> tokenizer.batch_decode(generated_ids)[0] +"The expected output" +``` + +This model was contributed by [geetu040](https://github.com/geetu040). +The original code can be found [here](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf/blob/main/modeling_minimax.py).
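+
+MiniMax stores the running state of its lightning attention layers in a dedicated `MiniMaxCache` (see `MiniMaxCache` in the modeling code), next to the usual key/value cache of the softmax attention layers; the model builds one automatically when `use_cache=True` and no cache is passed. The sketch below, referenced in the sliding window section above, passes the cache explicitly. It is only an illustration: the internal import path is an implementation detail and may change.
+
+```python
+>>> import torch
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+>>> from transformers.models.minimax.modeling_minimax import MiniMaxCache
+
+>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", torch_dtype=torch.float16, device_map="auto")
+>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
+
+>>> model_inputs = tokenizer(["My favourite condiment is"], return_tensors="pt").to("cuda")
+
+>>> # holds the per-layer linear-attention state (K.T @ V) alongside the standard key/value cache
+>>> past_key_values = MiniMaxCache()
+>>> generated_ids = model.generate(**model_inputs, max_new_tokens=20, past_key_values=past_key_values, do_sample=False)
+>>> tokenizer.batch_decode(generated_ids)[0]
+"The expected output"
+```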
+ +## Resources + +A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with MiniMax. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. + + + +- The [Alignment Handbook](https://github.com/huggingface/alignment-handbook) by Hugging Face includes scripts and recipes to perform supervised fine-tuning (SFT) and direct preference optimization with Mistral-7B. This includes scripts for full fine-tuning, QLoRa on a single GPU as well as multi-GPU fine-tuning. +- [Causal language modeling task guide](../tasks/language_modeling) + +## MiniMaxConfig + +[[autodoc]] MiniMaxConfig + +## MiniMaxModel + +[[autodoc]] MiniMaxModel + - forward + +## MiniMaxForCausalLM + +[[autodoc]] MiniMaxForCausalLM + - forward + +## MiniMaxForSequenceClassification + +[[autodoc]] MiniMaxForSequenceClassification + - forward + +## MiniMaxForTokenClassification + +[[autodoc]] MiniMaxForTokenClassification + - forward + +## MiniMaxForQuestionAnswering +[[autodoc]] MiniMaxForQuestionAnswering + - forward diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 205a7dde8f2..74bf8cca488 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -1218,6 +1218,7 @@ ALLOWED_LAYER_TYPES = ( "full_attention", "sliding_attention", "chunked_attention", + "linear_attention", # used in minimax ) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9af619c8203..4a549fc2155 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1976,6 +1976,7 @@ class GenerationMixin(ContinuousMixin): and "jamba" not in self.__class__.__name__.lower() and "zamba" not in self.__class__.__name__.lower() and "bamba" not in self.__class__.__name__.lower() + and "minimax" not in self.__class__.__name__.lower() ) def _prepare_cache_for_generation( diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 151eb0844d0..fd4b889fd37 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -185,6 +185,7 @@ if TYPE_CHECKING: from .megatron_gpt2 import * from .mgp_str import * from .mimi import * + from .minimax import * from .mistral import * from .mistral3 import * from .mixtral import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 056516e7318..b6ee5091705 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -211,6 +211,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str]( ("megatron-bert", "MegatronBertConfig"), ("mgp-str", "MgpstrConfig"), ("mimi", "MimiConfig"), + ("minimax", "MiniMaxConfig"), ("mistral", "MistralConfig"), ("mistral3", "Mistral3Config"), ("mixtral", "MixtralConfig"), @@ -586,6 +587,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str]( ("megatron_gpt2", "Megatron-GPT2"), ("mgp-str", "MGP-STR"), ("mimi", "Mimi"), + ("minimax", "MiniMax"), ("mistral", "Mistral"), ("mistral3", "Mistral3"), ("mixtral", "Mixtral"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 8ea758c1d72..e28e978ca8d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ 
b/src/transformers/models/auto/modeling_auto.py @@ -201,6 +201,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("megatron-bert", "MegatronBertModel"), ("mgp-str", "MgpstrForSceneTextRecognition"), ("mimi", "MimiModel"), + ("minimax", "MiniMaxModel"), ("mistral", "MistralModel"), ("mistral3", "Mistral3Model"), ("mixtral", "MixtralModel"), @@ -594,6 +595,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("mbart", "MBartForCausalLM"), ("mega", "MegaForCausalLM"), ("megatron-bert", "MegatronBertForCausalLM"), + ("minimax", "MiniMaxForCausalLM"), ("mistral", "MistralForCausalLM"), ("mixtral", "MixtralForCausalLM"), ("mllama", "MllamaForCausalLM"), @@ -1106,6 +1108,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("mbart", "MBartForSequenceClassification"), ("mega", "MegaForSequenceClassification"), ("megatron-bert", "MegatronBertForSequenceClassification"), + ("minimax", "MiniMaxForSequenceClassification"), ("mistral", "MistralForSequenceClassification"), ("mixtral", "MixtralForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"), @@ -1197,6 +1200,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("mbart", "MBartForQuestionAnswering"), ("mega", "MegaForQuestionAnswering"), ("megatron-bert", "MegatronBertForQuestionAnswering"), + ("minimax", "MiniMaxForQuestionAnswering"), ("mistral", "MistralForQuestionAnswering"), ("mixtral", "MixtralForQuestionAnswering"), ("mobilebert", "MobileBertForQuestionAnswering"), @@ -1303,6 +1307,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("markuplm", "MarkupLMForTokenClassification"), ("mega", "MegaForTokenClassification"), ("megatron-bert", "MegatronBertForTokenClassification"), + ("minimax", "MiniMaxForTokenClassification"), ("mistral", "MistralForTokenClassification"), ("mixtral", "MixtralForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index f2344daba02..cba0e2e1cde 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -342,6 +342,13 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]]( ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("mgp-str", ("MgpstrTokenizer", None)), + ( + "minimax", + ( + "GPT2Tokenizer" if is_sentencepiece_available() else None, + "GPT2TokenizerFast" if is_tokenizers_available() else None, + ), + ), ( "mistral", ( diff --git a/src/transformers/models/minimax/__init__.py b/src/transformers/models/minimax/__init__.py new file mode 100644 index 00000000000..91834eb6a22 --- /dev/null +++ b/src/transformers/models/minimax/__init__.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_minimax import * + from .modeling_minimax import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/minimax/configuration_minimax.py b/src/transformers/models/minimax/configuration_minimax.py new file mode 100644 index 00000000000..c0d8611af5c --- /dev/null +++ b/src/transformers/models/minimax/configuration_minimax.py @@ -0,0 +1,230 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax/modular_minimax.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig, layer_type_validation + + +class MiniMaxConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate an + MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the MiniMax. + + [MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MiniMaxModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabeling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + router_jitter_noise (`float`, *optional*, defaults to 0.0): + Amount of noise to add to the router. + layer_types (`list`, *optional*): + Attention pattern for each layer. + block_size (`int`, *optional*, defaults to 256): + The length of each attention block, determining how queries, keys, and values + are grouped and processed for intra- and inter-block attention. + full_attn_alpha_factor (`float`, *optional*, defaults to 1): + Weight for residual value in residual connection after normal attention. + full_attn_beta_factor (`float`, *optional*, defaults to 1): + Weight for hidden state value in residual connection after normal attention. + linear_attn_alpha_factor (`float`, *optional*, defaults to 1): + Weight for residual value in residual connection after lightning attention. + linear_attn_beta_factor (`float`, *optional*, defaults to 1): + Weight for hidden state value in residual connection after lightning attention. 
+ mlp_alpha_factor (`float`, *optional*, defaults to 1): + Weight for residual value in residual connection after MLP. + mlp_beta_factor (`float`, *optional*, defaults to 1): + Weight for hidden state value in residual connection after MLP. + + ```python + >>> from transformers import MiniMaxModel, MiniMaxConfig + + >>> # Initializing a MiniMax style configuration + >>> configuration = MiniMaxConfig() + + >>> # Initializing a model from the MiniMax style configuration + >>> model = MiniMaxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "minimax" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts + "layers.*.block_sparse_moe.experts.*.w1": "colwise", + "layers.*.block_sparse_moe.experts.*.w2": "rowwise", + "layers.*.block_sparse_moe.experts.*.w3": "colwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + head_dim=None, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + layer_types=None, + block_size=256, + full_attn_alpha_factor=1, + full_attn_beta_factor=1, + linear_attn_alpha_factor=1, + linear_attn_beta_factor=1, + mlp_alpha_factor=1, + mlp_beta_factor=1, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.head_dim = head_dim + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + self.layer_types = layer_types + self.block_size = block_size + self.full_attn_alpha_factor = full_attn_alpha_factor + self.full_attn_beta_factor = full_attn_beta_factor + self.linear_attn_alpha_factor = linear_attn_alpha_factor + self.linear_attn_beta_factor = 
linear_attn_beta_factor + self.mlp_alpha_factor = mlp_alpha_factor + self.mlp_beta_factor = mlp_beta_factor + + if self.layer_types is None: + self.layer_types = [ + "full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + + +__all__ = ["MiniMaxConfig"] diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py new file mode 100644 index 00000000000..18d2e4df7d9 --- /dev/null +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -0,0 +1,1259 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax/modular_minimax.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging +from .configuration_minimax import MiniMaxConfig + + +logger = logging.get_logger(__name__) + + +@use_kernel_forward_from_hub("RMSNorm") +class MiniMaxRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniMaxRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MiniMaxCache(DynamicCache): + def __init__(self): + super().__init__() + self.linear_cache: 
List[torch.Tensor] = [] + + def set_linear_cache(self, layer_idx, linear_cache): + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.linear_cache), layer_idx + 1): + self.linear_cache.append([]) + self.linear_cache[layer_idx] = linear_cache + + def get_linear_cache(self, layer_idx: int): + if layer_idx < len(self): + return self.linear_cache[layer_idx] + return None + + def __len__(self): + return max(super().__len__(), len(self.linear_cache)) + + def __getitem__(self, layer_idx: int): + if layer_idx < len(self.linear_cache) and self.linear_cache[layer_idx] != []: + return (self.linear_cache[layer_idx],) + return super().__getitem__(layer_idx) + + def __iter__(self): + for layer_idx in range(len(self)): + yield self[layer_idx] + + def batch_repeat_interleave(self, repeats: int): + for layer_idx in range(len(self)): + if self.linear_cache[layer_idx] != []: + self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) + else: + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor): + for layer_idx in range(len(self)): + if self.linear_cache[layer_idx] != []: + self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] + else: + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + + def crop(self, max_length: int): + raise RuntimeError("MiniMaxCache doesnot support `crop` method") + + +class MiniMaxLightningAttention(nn.Module): + def __init__(self, config: MiniMaxConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + self.num_attention_heads = config.num_attention_heads + self.num_hidden_layers = config.num_hidden_layers + self.block_size = config.block_size + + self.act_fn = ACT2FN[config.hidden_act] + self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads) + self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False) + self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False) + + slope_rate = self.get_slope_rate() + query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate) + + self.register_buffer("slope_rate", slope_rate) + self.register_buffer("query_decay", query_decay) + self.register_buffer("key_decay", key_decay) + self.register_buffer("diagonal_decay", diagonal_decay) + + def get_slope_rate(self): + base = 1 / (2 ** (8 / self.num_attention_heads)) + exponent = torch.arange(self.num_attention_heads) + 1 + factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5 + + rate = base**exponent + rate = rate * factor + rate = rate[:, None, None] + + return rate + + def decay_factors(self, slope_rate): + block_size_range = torch.arange(self.block_size) + 1 + + query_decay = torch.exp(-slope_rate * block_size_range[:, None]) + key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None])) + + diagonal_decay = block_size_range[:, None] - block_size_range[None, :] + diagonal_decay = diagonal_decay[None, None, :, :] + diagonal_decay = slope_rate * diagonal_decay + 
diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf")) + diagonal_decay = torch.exp(diagonal_decay) + + return query_decay, key_decay, diagonal_decay + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + batch_size, seq_len, hidden_size = hidden_states.shape + num_blocks = (seq_len + self.block_size - 1) // self.block_size + + qkv_states = self.act_fn(self.qkv_proj(hidden_states)) + qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim) + + query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=3) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # calculated (K.T @ V) and saved as cache + attn_weights_inter = None + if past_key_value is not None: + attn_weights_inter = past_key_value.get_linear_cache(self.layer_idx) + + if attn_weights_inter is None: + attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to( + value_states + ) + + # apply attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(dtype=torch.bool) # Ensure it's a boolean tensor + value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0) + + attn_output = [] + for i in range(num_blocks): + start_idx = i * self.block_size + end_idx = min(start_idx + self.block_size, seq_len) + current_block_size = end_idx - start_idx + + current_query_states = query_states[:, :, start_idx:end_idx] + current_key_states = key_states[:, :, start_idx:end_idx] + current_value_states = value_states[:, :, start_idx:end_idx] + + current_query_decay = self.query_decay[:, :current_block_size] + current_key_decay = self.key_decay[:, -current_block_size:] + current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size] + block_decay = torch.exp(-self.slope_rate * current_block_size) + + # intra: ( Q @ K.T ) @ V -> QK * V + attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2)) + attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states) + + # inter: Q @ ( K.T @ V ) -> Q * KV + attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter) + + # final attention output + current_attn_output = attn_output_inter + attn_output_intra + attn_output.append(current_attn_output) + + # cacluate attn_weights_inter for next block or cache + next_attn_weights_inter = torch.matmul( + (current_key_states * current_key_decay).transpose(-1, -2), current_value_states + ) + attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter + + else: + ratio = torch.exp(-self.slope_rate) + attn_output = [] + for i in range(seq_len): + current_query_states = query_states[:, :, i : i + 1] + current_key_states = key_states[:, :, i : i + 1] + current_value_states = value_states[:, :, i : i + 1] + + current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states) + attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter + current_attn_output = 
torch.matmul(current_query_states, attn_weights_inter) + + attn_output.append(current_attn_output) + + # concatenate attention outputs over all blocks + attn_output = torch.cat(attn_output, dim=-2) + + # final output projection + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim) + attn_output = self.norm(attn_output) + attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output + attn_output = self.out_proj(attn_output) + + # update cache + if past_key_value is not None: + past_key_value.set_linear_cache(self.layer_idx, attn_weights_inter) + + return attn_output, attn_weights_inter + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class MiniMaxAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MiniMaxConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface 
= ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class MiniMaxBlockSparseTop2MLP(nn.Module): + def __init__(self, config: MiniMaxConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MiniMaxSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([MiniMaxBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + # Jitter parameters + self.jitter_noise = config.router_jitter_noise + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist() + for expert_idx in expert_hitted: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + # Index the correct 
hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MiniMaxDecoderLayer(nn.Module): + def __init__(self, config: MiniMaxConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MiniMaxAttention(config, layer_idx) + + self.block_sparse_moe = MiniMaxSparseMoeBlock(config) + self.input_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.layer_idx = layer_idx + self.layer_type = config.layer_types[layer_idx] + self.mlp_alpha_factor = config.mlp_alpha_factor + self.mlp_beta_factor = config.mlp_beta_factor + + if self.layer_type == "linear_attention": + self.self_attn = MiniMaxLightningAttention(config, layer_idx) + self.attn_alpha_factor = config.linear_attn_alpha_factor + self.attn_beta_factor = config.linear_attn_beta_factor + else: + self.self_attn = MiniMaxAttention(config, layer_idx) + self.attn_alpha_factor = config.full_attn_alpha_factor + self.attn_beta_factor = config.full_attn_beta_factor + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + attention_mask (`torch.Tensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + hidden_states = self.input_layernorm(hidden_states) + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor + + # Fully Connected + hidden_states = self.post_attention_layernorm(hidden_states) + residual = hidden_states + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +@auto_docstring +class MiniMaxPreTrainedModel(PreTrainedModel): + config_class = MiniMaxConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MiniMaxDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True # Note: only supports MiniMaxCache + _supports_quantized_cache = False + _supports_static_cache = False + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MiniMaxRMSNorm): + module.weight.data.fill_(1.0) + + +class MiniMaxRotaryEmbedding(nn.Module): + def __init__(self, config: MiniMaxConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class MiniMaxModel(MiniMaxPreTrainedModel): + def __init__(self, config: MiniMaxConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MiniMaxDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MiniMaxRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = MiniMaxCache() + elif use_cache and not isinstance(past_key_values, MiniMaxCache): + raise ValueError( + f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}." 
+ ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if decoder_layer.layer_type == "full_attention": + input_attention_mask = causal_mask + else: + # lightning attention uses original attention_mask, and uses it only for the first step + input_attention_mask = attention_mask + + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=input_attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
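+ + Example: + + A minimal, illustrative call (the random logits below are stand-ins for real router outputs; shapes follow the `Args` description above): + + ```python + >>> import torch + >>> # router logits from 2 layers, each of shape [batch_size * sequence_length, num_experts] + >>> gate_logits = tuple(torch.randn(4, 8) for _ in range(2)) + >>> aux_loss = load_balancing_loss_func(gate_logits, num_experts=8, top_k=2) + ```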
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class MiniMaxForCausalLM(MiniMaxPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = MiniMaxModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + 
output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, MiniMaxForCausalLM + + >>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +@auto_docstring( + custom_intro=""" + The MiniMax Model transformer with a sequence classification head on top (linear layer). + + [`MiniMaxForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. 
+ + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """ +) +class MiniMaxForSequenceClassification(MiniMaxPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MiniMaxModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +class MiniMaxForTokenClassification(MiniMaxPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MiniMaxModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class MiniMaxForQuestionAnswering(MiniMaxPreTrainedModel): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.model = MiniMaxModel(config) # diff with Llama: transformer->model + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> QuestionAnsweringModelOutput: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "MiniMaxPreTrainedModel", + "MiniMaxModel", + "MiniMaxForCausalLM", + "MiniMaxForSequenceClassification", + "MiniMaxForTokenClassification", + "MiniMaxForQuestionAnswering", +] diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py new file mode 100644 index 00000000000..9a44d666563 --- /dev/null +++ b/src/transformers/models/minimax/modular_minimax.py @@ -0,0 +1,644 @@ +# coding=utf-8 +# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MiniMax model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import layer_type_validation +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeModelOutputWithPast +from ...processing_utils import Unpack +from ...utils import logging +from ..mixtral.configuration_mixtral import MixtralConfig +from ..mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralDecoderLayer, + MixtralForCausalLM, + MixtralForQuestionAnswering, + MixtralForSequenceClassification, + MixtralForTokenClassification, + MixtralModel, + MixtralPreTrainedModel, + MixtralRMSNorm, +) + + +logger = logging.get_logger(__name__) + + +class MiniMaxConfig(MixtralConfig): + r""" + This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate an + MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the MiniMax. + + [MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MiniMaxModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details check out [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention + allows sequences of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to route per-token, can also be interpreted as the `top-k` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss.
See [the Switch Transformers paper](https://arxiv.org/abs/2101.03961) for more details. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + router_jitter_noise (`float`, *optional*, defaults to 0.0): + Amount of noise to add to the router. + layer_types (`list`, *optional*): + Attention pattern for each layer. + block_size (`int`, *optional*, defaults to 256): + The length of each attention block, determining how queries, keys, and values + are grouped and processed for intra- and inter-block attention. + full_attn_alpha_factor (`float`, *optional*, defaults to 1): + Weight for residual value in residual connection after normal attention. + full_attn_beta_factor (`float`, *optional*, defaults to 1): + Weight for hidden state value in residual connection after normal attention. + linear_attn_alpha_factor (`float`, *optional*, defaults to 1): + Weight for residual value in residual connection after lightning attention. + linear_attn_beta_factor (`float`, *optional*, defaults to 1): + Weight for hidden state value in residual connection after lightning attention. + mlp_alpha_factor (`float`, *optional*, defaults to 1): + Weight for residual value in residual connection after MLP. + mlp_beta_factor (`float`, *optional*, defaults to 1): + Weight for hidden state value in residual connection after MLP. + + ```python + >>> from transformers import MiniMaxModel, MiniMaxConfig + + >>> # Initializing a MiniMax style configuration + >>> configuration = MiniMaxConfig() + + >>> # Initializing a model from the MiniMax style configuration + >>> model = MiniMaxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + def __init__( + self, + layer_types=None, + block_size=256, + full_attn_alpha_factor=1, + full_attn_beta_factor=1, + linear_attn_alpha_factor=1, + linear_attn_beta_factor=1, + mlp_alpha_factor=1, + mlp_beta_factor=1, + **super_kwargs, + ): + super().__init__(**super_kwargs) + self.layer_types = layer_types + self.block_size = block_size + self.full_attn_alpha_factor = full_attn_alpha_factor + self.full_attn_beta_factor = full_attn_beta_factor + self.linear_attn_alpha_factor = linear_attn_alpha_factor + self.linear_attn_beta_factor = linear_attn_beta_factor + self.mlp_alpha_factor = mlp_alpha_factor + self.mlp_beta_factor = mlp_beta_factor + + if self.layer_types is None: + self.layer_types = [ + "full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + + +class MiniMaxRMSNorm(MixtralRMSNorm): + pass + + +class MiniMaxCache(DynamicCache): + def __init__(self): + super().__init__() + self.linear_cache: List[torch.Tensor] = [] + + def set_linear_cache(self, layer_idx, linear_cache): + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.linear_cache), layer_idx + 1): + self.linear_cache.append([]) + self.linear_cache[layer_idx] = linear_cache + + def get_linear_cache(self, layer_idx: int): + if layer_idx < len(self): + return self.linear_cache[layer_idx] + return None + + def __len__(self): + return max(super().__len__(), len(self.linear_cache)) + + def __getitem__(self, layer_idx: int): + if layer_idx < len(self.linear_cache) and self.linear_cache[layer_idx] != []: + return (self.linear_cache[layer_idx],) + return super().__getitem__(layer_idx) + + def __iter__(self): + for layer_idx in range(len(self)): + yield self[layer_idx] + + def batch_repeat_interleave(self, repeats: int): + for layer_idx in range(len(self)): + if
self.linear_cache[layer_idx] != []: + self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) + else: + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor): + for layer_idx in range(len(self)): + if self.linear_cache[layer_idx] != []: + self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] + else: + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + + def crop(self, max_length: int): + raise RuntimeError("MiniMaxCache does not support the `crop` method") + + +class MiniMaxLightningAttention(nn.Module): + def __init__(self, config: MiniMaxConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + self.num_attention_heads = config.num_attention_heads + self.num_hidden_layers = config.num_hidden_layers + self.block_size = config.block_size + + self.act_fn = ACT2FN[config.hidden_act] + self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads) + self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False) + self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False) + + slope_rate = self.get_slope_rate() + query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate) + + self.register_buffer("slope_rate", slope_rate) + self.register_buffer("query_decay", query_decay) + self.register_buffer("key_decay", key_decay) + self.register_buffer("diagonal_decay", diagonal_decay) + + def get_slope_rate(self): + base = 1 / (2 ** (8 / self.num_attention_heads)) + exponent = torch.arange(self.num_attention_heads) + 1 + factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5 + + rate = base**exponent + rate = rate * factor + rate = rate[:, None, None] + + return rate + + def decay_factors(self, slope_rate): + block_size_range = torch.arange(self.block_size) + 1 + + query_decay = torch.exp(-slope_rate * block_size_range[:, None]) + key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None])) + + diagonal_decay = block_size_range[:, None] - block_size_range[None, :] + diagonal_decay = diagonal_decay[None, None, :, :] + diagonal_decay = slope_rate * diagonal_decay + diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf")) + diagonal_decay = torch.exp(diagonal_decay) + + return query_decay, key_decay, diagonal_decay + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + batch_size, seq_len, hidden_size = hidden_states.shape + num_blocks = (seq_len + self.block_size - 1) // self.block_size + + qkv_states = self.act_fn(self.qkv_proj(hidden_states)) + qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim) + + query_states, key_states,
value_states = torch.split(qkv_states, self.head_dim, dim=3) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # calculated (K.T @ V) and saved as cache + attn_weights_inter = None + if past_key_value is not None: + attn_weights_inter = past_key_value.get_linear_cache(self.layer_idx) + + if attn_weights_inter is None: + attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to( + value_states + ) + + # apply attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(dtype=torch.bool) # Ensure it's a boolean tensor + value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0) + + attn_output = [] + for i in range(num_blocks): + start_idx = i * self.block_size + end_idx = min(start_idx + self.block_size, seq_len) + current_block_size = end_idx - start_idx + + current_query_states = query_states[:, :, start_idx:end_idx] + current_key_states = key_states[:, :, start_idx:end_idx] + current_value_states = value_states[:, :, start_idx:end_idx] + + current_query_decay = self.query_decay[:, :current_block_size] + current_key_decay = self.key_decay[:, -current_block_size:] + current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size] + block_decay = torch.exp(-self.slope_rate * current_block_size) + + # intra: ( Q @ K.T ) @ V -> QK * V + attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2)) + attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states) + + # inter: Q @ ( K.T @ V ) -> Q * KV + attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter) + + # final attention output + current_attn_output = attn_output_inter + attn_output_intra + attn_output.append(current_attn_output) + + # calculate attn_weights_inter for next block or cache + next_attn_weights_inter = torch.matmul( + (current_key_states * current_key_decay).transpose(-1, -2), current_value_states + ) + attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter + + else: + ratio = torch.exp(-self.slope_rate) + attn_output = [] + for i in range(seq_len): + current_query_states = query_states[:, :, i : i + 1] + current_key_states = key_states[:, :, i : i + 1] + current_value_states = value_states[:, :, i : i + 1] + + current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states) + attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter + current_attn_output = torch.matmul(current_query_states, attn_weights_inter) + + attn_output.append(current_attn_output) + + # concatenate attention outputs over all blocks + attn_output = torch.cat(attn_output, dim=-2) + + # final output projection + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim) + attn_output = self.norm(attn_output) + attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output + attn_output = self.out_proj(attn_output) + + # update cache + if past_key_value is not None: + past_key_value.set_linear_cache(self.layer_idx, attn_weights_inter) + + return attn_output, attn_weights_inter + + +class MiniMaxAttention(MixtralAttention): + pass + + +class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer): + def __init__(self, config: MiniMaxConfig, layer_idx:
int): + super().__init__(config, layer_idx) + + self.layer_idx = layer_idx + self.layer_type = config.layer_types[layer_idx] + self.mlp_alpha_factor = config.mlp_alpha_factor + self.mlp_beta_factor = config.mlp_beta_factor + + if self.layer_type == "linear_attention": + self.self_attn = MiniMaxLightningAttention(config, layer_idx) + self.attn_alpha_factor = config.linear_attn_alpha_factor + self.attn_beta_factor = config.linear_attn_beta_factor + else: + self.self_attn = MiniMaxAttention(config, layer_idx) + self.attn_alpha_factor = config.full_attn_alpha_factor + self.attn_beta_factor = config.full_attn_beta_factor + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + attention_mask (`torch.Tensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + hidden_states = self.input_layernorm(hidden_states) + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor + + # Fully Connected + hidden_states = self.post_attention_layernorm(hidden_states) + residual = hidden_states + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +class MiniMaxPreTrainedModel(MixtralPreTrainedModel): + _supports_cache_class = True # Note: only supports MiniMaxCache + _supports_static_cache = False + _supports_quantized_cache = False + + +class MiniMaxModel(MixtralModel): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = MiniMaxCache() + elif use_cache and not isinstance(past_key_values, MiniMaxCache): + raise ValueError( + f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}." 
+ ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if decoder_layer.layer_type == "full_attention": + input_attention_mask = causal_mask + else: + # lightning attention uses original attention_mask, and uses it only for the first step + input_attention_mask = attention_mask + + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=input_attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class MiniMaxForCausalLM(MixtralForCausalLM): + def forward(self, **super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, MiniMaxForCausalLM + + >>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + return super().forward(**super_kwargs) + + +class MiniMaxForSequenceClassification(MixtralForSequenceClassification): + pass + + +class MiniMaxForTokenClassification(MixtralForTokenClassification): + pass + + +class MiniMaxForQuestionAnswering(MixtralForQuestionAnswering): + pass + + +__all__ = [ + "MiniMaxConfig", + "MiniMaxPreTrainedModel", + "MiniMaxModel", + "MiniMaxForCausalLM", + "MiniMaxForSequenceClassification", + "MiniMaxForTokenClassification", + "MiniMaxForQuestionAnswering", +] diff --git a/tests/models/minimax/__init__.py b/tests/models/minimax/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/minimax/test_modeling_minimax.py b/tests/models/minimax/test_modeling_minimax.py new file mode 100644 index 00000000000..2d03b01f736 --- /dev/null +++ b/tests/models/minimax/test_modeling_minimax.py @@ -0,0 +1,279 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch MiniMax model.""" + +import unittest + +import pytest + +from transformers import MiniMaxConfig, is_torch_available +from transformers.cache_utils import Cache +from transformers.testing_utils import ( + require_flash_attn, + require_torch, + require_torch_accelerator, + require_torch_gpu, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + + from transformers import ( + MiniMaxForCausalLM, + MiniMaxForQuestionAnswering, + MiniMaxForSequenceClassification, + MiniMaxForTokenClassification, + MiniMaxModel, + ) + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class MiniMaxModelTester(CausalLMModelTester): + config_class = MiniMaxConfig + if is_torch_available(): + base_model_class = MiniMaxModel + causal_lm_class = MiniMaxForCausalLM + sequence_class = MiniMaxForSequenceClassification + token_class = MiniMaxForTokenClassification + question_answering_class = MiniMaxForQuestionAnswering + + def __init__(self, parent, layer_types=None, block_size=3): + super().__init__(parent) + self.layer_types = layer_types + self.block_size = block_size + + +@require_torch +class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase): + all_model_classes = ( + ( + MiniMaxModel, + MiniMaxForCausalLM, + MiniMaxForSequenceClassification, + MiniMaxForTokenClassification, + MiniMaxForQuestionAnswering, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "feature-extraction": MiniMaxModel, + "text-classification": MiniMaxForSequenceClassification, + "token-classification": MiniMaxForTokenClassification, + "text-generation": MiniMaxForCausalLM, + "question-answering": MiniMaxForQuestionAnswering, + } + if is_torch_available() + else {} + ) + + test_headmasking = False + test_pruning = False + model_tester_class = MiniMaxModelTester + + # TODO (ydshieh): Check this. 
See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest(reason="MiniMax flash attention does not support right padding") + + def test_load_balancing_loss(self): + r""" + Let's make sure we can actually compute the loss and do a backward on it. + """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.num_local_experts = 8 + config.output_router_logits = True + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + model = MiniMaxForCausalLM(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask) + self.assertEqual(result.router_logits[0].shape, (91, config.num_local_experts)) + torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2) + + # First, we make sure that adding padding tokens doesn't change the loss + # loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding) + pad_length = 1000 + # Add padding tokens (assume that pad_token_id=1) to input_ids + padding_block = torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(torch_device) + padded_input_ids = torch.cat((padding_block, input_ids), dim=1) # this is to simulate padding to the left + padded_attention_mask = padded_input_ids.ne(1).to(torch_device) + + padded_result = model(padded_input_ids, attention_mask=padded_attention_mask) + torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4) + + # We make sure that the loss of including padding tokens != the loss without padding tokens + # if attention_mask=None --> we don't exclude padding tokens + include_padding_result = model(padded_input_ids, attention_mask=None) + + # This is to mimic torch.testing.assert_not_close + self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + + def _check_attentions_for_generate( + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values + ): + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) + ) + self.assertEqual(len(attentions), (output_length - prompt_length)) + use_cache = decoder_past_key_values is not None + + for generated_length, iter_attentions in enumerate(attentions): + # regardless of using cache, the first forward pass will have the full prompt as input + if use_cache and generated_length > 0: + model_input_length = 1 + else: + model_input_length = prompt_length + generated_length + + expected_shape = ( + batch_size, + config.num_attention_heads, + model_input_length, + prompt_length + generated_length, + ) + for layer_idx, layer_attention in enumerate(iter_attentions): + if config.layer_types[layer_idx] == "full_attention": + self.assertEqual(layer_attention.shape, expected_shape) + + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + 
self.assertIsInstance(decoder_past_key_values, (tuple, Cache)) + + # (batch, head, seq_length, head_features) + key_value_cache_expected_shape = ( + batch_size, + config.num_key_value_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + # (batch, head, head_features, head_features) + linear_cache_expected_shape = ( + batch_size, + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + ) + + for layer_idx in range(config.num_hidden_layers): + if config.layer_types[layer_idx] == "full_attention": + self.assertEqual(decoder_past_key_values[layer_idx][0].shape, key_value_cache_expected_shape) + self.assertEqual(decoder_past_key_values[layer_idx][1].shape, key_value_cache_expected_shape) + else: + self.assertEqual(decoder_past_key_values[layer_idx][0].shape, linear_cache_expected_shape) + + @pytest.mark.generate + def test_past_key_values_format(self, custom_all_cache_shapes=None): + """ + Test that the KV cache is formatted correctly. + """ + for model_class in self.all_generative_model_classes: + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + model = model_class(config).to(torch_device) + model = model.eval() + if "use_cache" not in inputs: + inputs["use_cache"] = True + outputs = model(**inputs) + + past_kv = outputs["past_key_values"] + + batch_size, seq_length = inputs["input_ids"].shape + self._check_past_key_values_for_generate(batch_size, past_kv, seq_length, config) + + @unittest.skip(reason="MiniMaxCache does not support the `crop()` method") + def test_prompt_lookup_decoding_matches_greedy_search(self): + pass + + @unittest.skip(reason="MiniMaxCache does not support the `crop()` method") + def test_contrastive_generate_low_memory(self): + pass + + @unittest.skip(reason="MiniMaxCache does not support the `crop()` method") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip(reason="MiniMaxCache does not support the `crop()` method") + def test_assisted_decoding_matches_greedy_search_0_random(self): + pass + + @unittest.skip(reason="MiniMaxCache does not support the `crop()` method") + def test_assisted_decoding_matches_greedy_search_1_same(self): + pass + + @unittest.skip(reason="MiniMaxCache does not support the `crop()` method") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + +@require_torch +@require_torch_accelerator +@slow +class MiniMaxIntegrationTest(unittest.TestCase): + def test_small_model_logits(self): + model_id = "geetu040/MiniMax-tiny" + dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device) + + model = MiniMaxForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to( + torch_device + ) + expected_slice = torch.tensor( + [[1.0312, -0.5156, -0.3262], [-0.1152, 0.4336, 0.2412], [1.2188, -0.5898, -0.0381]] + ).to(torch_device) + + with torch.no_grad(): + logits = model(dummy_input).logits + + logits = logits.float() + + torch.testing.assert_close(logits[0, :3, :3], expected_slice, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(logits[1, :3, :3], expected_slice, atol=1e-3, rtol=1e-3) + + def test_small_model_generation(self): + model_id = "geetu040/MiniMax-tiny" + dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device) + + model = MiniMaxForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to( + torch_device + ) + expected_slice = ( + torch.tensor([[0, 1, 0, 933, 307, 3102, 2457, 1208], [0, 1, 0, 933, 307, 3102, 2457,
1208]]) + .to(torch.int64) + .to(torch_device) + ) + + outputs = model.generate(dummy_input, max_new_tokens=5, do_sample=False) + + torch.testing.assert_close(outputs, expected_slice, atol=1e-3, rtol=1e-3) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 446610db966..9011f341007 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3946,7 +3946,7 @@ class ModelTesterMixin: if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") - WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen_moe", "starcoder2"] + WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "minimax", "qwen2", "qwen_moe", "starcoder2"] if len(self.all_generative_model_classes) == 0: self.skipTest(f"No generative model classes for {self.__class__.__name__}")