Add support for MiniMax's MiniMax-Text-01 (#35831)

* end-to-end architecture

* lightning-attn: refactor, clean, optimize

* put minimax_text_01 in other files

* use latest __init__ standards and auto-generate modular

* support attention_mask for lightning-attn

* Revert "use latest __init__ standards and auto-generate modular"

This reverts commit d8d3c409d8.

* fix modular conversion

* pass both attention masks instead of tuple

* formatting

* Updated Dynamic Cache

* created MiniMaxText01Cache

* fix hardcoded slope_rate

* update attn_type_list in config

* fix lightning when use_cache=False

* copy tests from mixtral

* (checkpoint) all tests pass for normal attention

* fix all unittests

* fix import sorting

* fix consistency and formatting tests

* fix config

* update tests, since changes in main

* fix seq_len error

* create dummy docs

* fix checkpoint

* add checkpoint in config docstring

* run modular_conversion

* update docs

* fix checkpoint path and update tests

* fix ruff

* remove repeated expected_slice

* update docs

* rename "minimax-text-01" to "minimax"

* inherit config from mixtral

* remove from docs in other languages

* undo files that should be untouched

* move minimax to end in conversation docs

* use MiniMaxForCausalLM as it is

* ruff fixes

* run modular

* fix docstring example in causallm

* refactor attention loop and decay factors

* refactor config in modular

* run modular

* refactor cache

* rename static_cache to linear_cache

* make positional embeddings necessary

* remove unnecessary layernorms declarations

* fix import in tests

* refactor attention in next tokens

* remove outdated code

* formatting and modular

* update tests

* rename layernorm alpha/beta factors

* register decay factors as buffers

* remove unused declarations of decay factors

* update config for alpha/beta factors

* run modular

* remove head_dim in tests

* remove minimax from fx.py

* remove stuff that is not really needed

* update __init__

* update qkv torch.split

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* fix qkv torch.split

* quality fixes

* remove mistakenly added dummy

* purge unused ModelTester code

* fix-copies

* run fix-copies

* fix head_dim

* write cache formatting tests

* remove postnorm

* avoid contiguous in attention current states

* update expected_slice

* add generation test for integration

* fix dtype in generation test

* update authors

* update with changes in main

* update gradient checkpointing and minor fixes

* fix mutable attn_type_list

* rename: attn_type -> layer_type

* update for layer_types

* update integration tests

* update checkpoint

* clean overview in docs

---------

Co-authored-by: Shakib-IO <shakib.khan17@northsouth.edu>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Armaghan Shakir 2025-06-04 12:38:40 +05:00 committed by GitHub
parent 037acf1d10
commit 55736eea99
15 changed files with 2650 additions and 1 deletions


@ -555,6 +555,8 @@
      title: MegatronBERT
    - local: model_doc/megatron_gpt2
      title: MegatronGPT2
    - local: model_doc/minimax
      title: MiniMax
    - local: model_doc/mistral
      title: Mistral
    - local: model_doc/mixtral


@ -0,0 +1,189 @@
<!--Copyright 2025 MiniMaxAI and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# MiniMax
## Overview
The MiniMax model was proposed in [MiniMax-01: Scaling Foundation Models with Lightning Attention](https://arxiv.org/abs/2501.08313) by MiniMax, Aonian Li, Bangwei Gong, Bo Yang, Boji Shan, Chang Liu, Cheng Zhu, Chunhao Zhang, Congchao Guo, Da Chen, Dong Li, Enwei Jiao, Gengxin Li, Guojun Zhang, Haohai Sun, Houze Dong, Jiadai Zhu, Jiaqi Zhuang, Jiayuan Song, Jin Zhu, Jingtao Han, Jingyang Li, Junbin Xie, Junhao Xu, Junjie Yan, Kaishun Zhang, Kecheng Xiao, Kexi Kang, Le Han, Leyang Wang, Lianfei Yu, Liheng Feng, Lin Zheng, Linbo Chai, Long Xing, Meizhi Ju, Mingyuan Chi, Mozhi Zhang, Peikai Huang, Pengcheng Niu, Pengfei Li, Pengyu Zhao, Qi Yang, Qidi Xu, Qiexiang Wang, Qin Wang, Qiuhui Li, Ruitao Leng, Shengmin Shi, Shuqi Yu, Sichen Li, Songquan Zhu, Tao Huang, Tianrun Liang, Weigao Sun, Weixuan Sun, Weiyu Cheng, Wenkai Li, Xiangjun Song, Xiao Su, Xiaodong Han, Xinjie Zhang, Xinzhu Hou, Xu Min, Xun Zou, Xuyang Shen, Yan Gong, Yingjie Zhu, Yipeng Zhou, Yiran Zhong, Yongyi Hu, Yuanxiang Fan, Yue Yu, Yufeng Yang, Yuhao Li, Yunan Huang, Yunji Li, Yunpeng Huang, Yunzhi Xu, Yuxin Mao, Zehan Li, Zekang Li, Zewei Tao, Zewen Ying, Zhaoyang Cong, Zhen Qin, Zhenhua Fan, Zhihang Yu, Zhuo Jiang, Zijia Wu.
The abstract from the paper is the following:
*We introduce MiniMax-01 series, including MiniMax-Text-01 and MiniMax-VL-01, which are comparable to top-tier models while offering superior capabilities in processing longer contexts. The core lies in lightning attention and its efficient scaling. To maximize computational capacity, we integrate it with Mixture of Experts (MoE), creating a model with 32 experts and 456 billion total parameters, of which 45.9 billion are activated for each token. We develop an optimized parallel strategy and highly efficient computation-communication overlap techniques for MoE and lightning attention. This approach enables us to conduct efficient training and inference on models with hundreds of billions of parameters across contexts spanning millions of tokens. The context window of MiniMax-Text-01 can reach up to 1 million tokens during training and extrapolate to 4 million tokens during inference at an affordable cost. Our vision-language model, MiniMax-VL-01 is built through continued training with 512 billion vision-language tokens. Experiments on both standard and in-house benchmarks show that our models match the performance of state-of-the-art models like GPT-4o and Claude-3.5-Sonnet while offering 20-32 times longer context window.*
### Architectural details
MiniMax is a powerful language model with 456 billion total parameters, of which 45.9 billion are activated per token. To better unlock the long-context capabilities of the model, MiniMax adopts a hybrid architecture that combines Lightning Attention, Softmax Attention and Mixture-of-Experts (MoE). Leveraging advanced parallel strategies and innovative compute-communication overlap methods, such as Linear Attention Sequence Parallelism Plus (LASP+), varlen ring attention and Expert Tensor Parallel (ETP), MiniMax's training context length is extended to 1 million tokens, and it can handle a context of up to 4 million tokens during inference. On various academic benchmarks, MiniMax also demonstrates the performance of a top-tier model.
The architecture of MiniMax is briefly described as follows:
- Total Parameters: 456B
- Activated Parameters per Token: 45.9B
- Number of Layers: 80
- Hybrid Attention: a softmax attention layer is positioned after every 7 lightning attention layers (see the configuration sketch below).
- Number of attention heads: 64
- Attention head dimension: 128
- Mixture of Experts:
- Number of experts: 32
- Expert hidden dimension: 9216
- Top-2 routing strategy
- Positional Encoding: Rotary Position Embedding (RoPE) applied to half of the attention head dimension with a base frequency of 10,000,000
- Hidden Size: 6144
- Vocab Size: 200,064
For more details refer to the [release blog post](https://www.minimaxi.com/en/news/minimax-01-series-2).
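In the Transformers configuration, this hybrid layout is controlled by the `layer_types` field of [`MiniMaxConfig`]. A minimal sketch with toy sizes (illustrative values, not the released checkpoint's configuration):

```python
>>> from transformers import MiniMaxConfig

>>> # one softmax ("full") attention layer after every 7 lightning ("linear") attention layers
>>> layer_types = ["full_attention" if (i + 1) % 8 == 0 else "linear_attention" for i in range(16)]
>>> config = MiniMaxConfig(num_hidden_layers=16, layer_types=layer_types)
>>> config.layer_types.count("full_attention"), config.layer_types.count("linear_attention")
(2, 14)
```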
### License
`MiniMax` is released under the MINIMAX MODEL LICENSE AGREEMENT.
## Usage tips
The pre-trained model can be used as follows:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
>>> messages = [
... {"role": "user", "content": "What is your favourite condiment?"},
... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
... {"role": "user", "content": "Do you have mayonnaise recipes?"}
... ]
>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"Mayonnaise can be made as follows: (...)"
```
As can be seen, the instruction-tuned model requires a [chat template](../chat_templating) to be applied to make sure the inputs are prepared in the right format.
## Speeding up MiniMax by using Flash Attention
The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
```bash
pip install -U flash-attn --no-build-isolation
```
Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). Also make sure to load your model in half-precision (e.g. `torch.float16`).
To load and run a model using Flash Attention-2, refer to the snippet below:
```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
>>> prompt = "My favourite condiment is"
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
```
### Sliding window Attention
The current implementation supports the sliding window attention mechanism and memory-efficient cache management.
To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`).
The Flash Attention-2 model also uses a more memory-efficient cache slicing mechanism. As recommended by the official implementation of the Mistral model, which uses a rolling cache mechanism, we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"`, and use the absolute position of the current token to compute the positional embedding.
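For batched generation, this means inputs should be padded on the left. A minimal sketch (the prompts are illustrative, and the fallback pad token is an assumption for checkpoints that do not define one):

```python
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", padding_side="left")
>>> tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token  # assumption: reuse EOS if no pad token is set
>>> batch = tokenizer(["My favourite condiment is", "Hello"], padding=True, return_tensors="pt")
>>> batch.input_ids.shape[0]
2
```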
## Shrinking down MiniMax using quantization
As the MiniMax model has 456 billion parameters, it would require about 912GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter), about 228 GB of RAM is required.
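As a quick back-of-the-envelope check (plain Python, not part of the library):

```python
>>> num_parameters = 456e9
>>> print(f"{num_parameters * 2 / 1e9:.0f} GB in float16")  # 2 bytes per parameter
912 GB in float16
>>> print(f"{num_parameters * 0.5 / 1e9:.0f} GB in 4-bit")  # half a byte per parameter
228 GB in 4-bit
```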
Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization.md) for alternative quantization methods):
```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
>>> # specify how to quantize the model
>>> quantization_config = BitsAndBytesConfig(
... load_in_4bit=True,
... bnb_4bit_quant_type="nf4",
...     bnb_4bit_compute_dtype=torch.float16,
... )
>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", quantization_config=True, device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
>>> prompt = "My favourite condiment is"
>>> messages = [
... {"role": "user", "content": "What is your favourite condiment?"},
... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
... {"role": "user", "content": "Do you have mayonnaise recipes?"}
... ]
>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
```
This model was contributed by [geetu040](https://github.com/geetu040).
The original code can be found [here](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf/blob/main/modeling_minimax.py).
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with MiniMax. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
<PipelineTag pipeline="text-generation"/>
- The [Alignment Handbook](https://github.com/huggingface/alignment-handbook) by Hugging Face includes scripts and recipes to perform supervised fine-tuning (SFT) and direct preference optimization with Mistral-7B. This includes scripts for full fine-tuning, QLoRA on a single GPU as well as multi-GPU fine-tuning.
- [Causal language modeling task guide](../tasks/language_modeling)
## MiniMaxConfig
[[autodoc]] MiniMaxConfig
## MiniMaxModel
[[autodoc]] MiniMaxModel
- forward
## MiniMaxForCausalLM
[[autodoc]] MiniMaxForCausalLM
- forward
## MiniMaxForSequenceClassification
[[autodoc]] MiniMaxForSequenceClassification
- forward
## MiniMaxForTokenClassification
[[autodoc]] MiniMaxForTokenClassification
- forward
## MiniMaxForQuestionAnswering
[[autodoc]] MiniMaxForQuestionAnswering
- forward


@ -1218,6 +1218,7 @@ ALLOWED_LAYER_TYPES = (
"full_attention", "full_attention",
"sliding_attention", "sliding_attention",
"chunked_attention", "chunked_attention",
"linear_attention", # used in minimax
) )


@ -1976,6 +1976,7 @@ class GenerationMixin(ContinuousMixin):
and "jamba" not in self.__class__.__name__.lower() and "jamba" not in self.__class__.__name__.lower()
and "zamba" not in self.__class__.__name__.lower() and "zamba" not in self.__class__.__name__.lower()
and "bamba" not in self.__class__.__name__.lower() and "bamba" not in self.__class__.__name__.lower()
and "minimax" not in self.__class__.__name__.lower()
) )
def _prepare_cache_for_generation( def _prepare_cache_for_generation(


@ -185,6 +185,7 @@ if TYPE_CHECKING:
    from .megatron_gpt2 import *
    from .mgp_str import *
    from .mimi import *
    from .minimax import *
    from .mistral import *
    from .mistral3 import *
    from .mixtral import *


@ -211,6 +211,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("megatron-bert", "MegatronBertConfig"), ("megatron-bert", "MegatronBertConfig"),
("mgp-str", "MgpstrConfig"), ("mgp-str", "MgpstrConfig"),
("mimi", "MimiConfig"), ("mimi", "MimiConfig"),
("minimax", "MiniMaxConfig"),
("mistral", "MistralConfig"), ("mistral", "MistralConfig"),
("mistral3", "Mistral3Config"), ("mistral3", "Mistral3Config"),
("mixtral", "MixtralConfig"), ("mixtral", "MixtralConfig"),
@ -586,6 +587,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("megatron_gpt2", "Megatron-GPT2"), ("megatron_gpt2", "Megatron-GPT2"),
("mgp-str", "MGP-STR"), ("mgp-str", "MGP-STR"),
("mimi", "Mimi"), ("mimi", "Mimi"),
("minimax", "MiniMax"),
("mistral", "Mistral"), ("mistral", "Mistral"),
("mistral3", "Mistral3"), ("mistral3", "Mistral3"),
("mixtral", "Mixtral"), ("mixtral", "Mixtral"),


@ -201,6 +201,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("megatron-bert", "MegatronBertModel"), ("megatron-bert", "MegatronBertModel"),
("mgp-str", "MgpstrForSceneTextRecognition"), ("mgp-str", "MgpstrForSceneTextRecognition"),
("mimi", "MimiModel"), ("mimi", "MimiModel"),
("minimax", "MiniMaxModel"),
("mistral", "MistralModel"), ("mistral", "MistralModel"),
("mistral3", "Mistral3Model"), ("mistral3", "Mistral3Model"),
("mixtral", "MixtralModel"), ("mixtral", "MixtralModel"),
@ -594,6 +595,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("mbart", "MBartForCausalLM"), ("mbart", "MBartForCausalLM"),
("mega", "MegaForCausalLM"), ("mega", "MegaForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"), ("megatron-bert", "MegatronBertForCausalLM"),
("minimax", "MiniMaxForCausalLM"),
("mistral", "MistralForCausalLM"), ("mistral", "MistralForCausalLM"),
("mixtral", "MixtralForCausalLM"), ("mixtral", "MixtralForCausalLM"),
("mllama", "MllamaForCausalLM"), ("mllama", "MllamaForCausalLM"),
@ -1106,6 +1108,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("mbart", "MBartForSequenceClassification"), ("mbart", "MBartForSequenceClassification"),
("mega", "MegaForSequenceClassification"), ("mega", "MegaForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"), ("megatron-bert", "MegatronBertForSequenceClassification"),
("minimax", "MiniMaxForSequenceClassification"),
("mistral", "MistralForSequenceClassification"), ("mistral", "MistralForSequenceClassification"),
("mixtral", "MixtralForSequenceClassification"), ("mixtral", "MixtralForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"),
@ -1197,6 +1200,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("mbart", "MBartForQuestionAnswering"), ("mbart", "MBartForQuestionAnswering"),
("mega", "MegaForQuestionAnswering"), ("mega", "MegaForQuestionAnswering"),
("megatron-bert", "MegatronBertForQuestionAnswering"), ("megatron-bert", "MegatronBertForQuestionAnswering"),
("minimax", "MiniMaxForQuestionAnswering"),
("mistral", "MistralForQuestionAnswering"), ("mistral", "MistralForQuestionAnswering"),
("mixtral", "MixtralForQuestionAnswering"), ("mixtral", "MixtralForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"), ("mobilebert", "MobileBertForQuestionAnswering"),
@ -1303,6 +1307,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("markuplm", "MarkupLMForTokenClassification"), ("markuplm", "MarkupLMForTokenClassification"),
("mega", "MegaForTokenClassification"), ("mega", "MegaForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"), ("megatron-bert", "MegatronBertForTokenClassification"),
("minimax", "MiniMaxForTokenClassification"),
("mistral", "MistralForTokenClassification"), ("mistral", "MistralForTokenClassification"),
("mixtral", "MixtralForTokenClassification"), ("mixtral", "MixtralForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"),


@ -342,6 +342,13 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("mgp-str", ("MgpstrTokenizer", None)), ("mgp-str", ("MgpstrTokenizer", None)),
(
"minimax",
(
"GPT2Tokenizer" if is_sentencepiece_available() else None,
"GPT2TokenizerFast" if is_tokenizers_available() else None,
),
),
( (
"mistral", "mistral",
( (


@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_minimax import *
from .modeling_minimax import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@ -0,0 +1,230 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/minimax/modular_minimax.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_minimax.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig, layer_type_validation
class MiniMaxConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate a
MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the MiniMax.
[MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MiniMaxModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 14336):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention
allows sequences of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `4096`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
num_experts_per_tok (`int`, *optional*, defaults to 2):
The number of experts to route per token; can also be interpreted as the `top-k` routing
parameter
num_local_experts (`int`, *optional*, defaults to 8):
Number of experts per Sparse MLP layer.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
router_jitter_noise (`float`, *optional*, defaults to 0.0):
Amount of noise to add to the router.
layer_types (`list`, *optional*):
Attention pattern for each layer.
block_size (`int`, *optional*, defaults to 256):
The length of each attention block, determining how queries, keys, and values
are grouped and processed for intra- and inter-block attention.
full_attn_alpha_factor (`float`, *optional*, defaults to 1):
Weight for residual value in residual connection after normal attention.
full_attn_beta_factor (`float`, *optional*, defaults to 1):
Weight for hidden state value in residual connection after normal attention.
linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
Weight for residual value in residual connection after lightning attention.
linear_attn_beta_factor (`float`, *optional*, defaults to 1):
Weight for hidden state value in residual connection after lightning attention.
mlp_alpha_factor (`float`, *optional*, defaults to 1):
Weight for residual value in residual connection after MLP.
mlp_beta_factor (`float`, *optional*, defaults to 1):
Weight for hidden state value in residual connection after MLP.
```python
>>> from transformers import MiniMaxModel, MiniMaxConfig
>>> # Initializing a MiniMax style configuration
>>> configuration = MiniMaxConfig()
>>> # Initializing a model from the MiniMax style configuration
>>> model = MiniMaxModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "minimax"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
"layers.*.block_sparse_moe.experts.*.w1": "colwise",
"layers.*.block_sparse_moe.experts.*.w2": "rowwise",
"layers.*.block_sparse_moe.experts.*.w3": "colwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
head_dim=None,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=1e6,
sliding_window=None,
attention_dropout=0.0,
num_experts_per_tok=2,
num_local_experts=8,
output_router_logits=False,
router_aux_loss_coef=0.001,
router_jitter_noise=0.0,
layer_types=None,
block_size=256,
full_attn_alpha_factor=1,
full_attn_beta_factor=1,
linear_attn_alpha_factor=1,
linear_attn_beta_factor=1,
mlp_alpha_factor=1,
mlp_beta_factor=1,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.head_dim = head_dim
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
self.layer_types = layer_types
self.block_size = block_size
self.full_attn_alpha_factor = full_attn_alpha_factor
self.full_attn_beta_factor = full_attn_beta_factor
self.linear_attn_alpha_factor = linear_attn_alpha_factor
self.linear_attn_beta_factor = linear_attn_beta_factor
self.mlp_alpha_factor = mlp_alpha_factor
self.mlp_beta_factor = mlp_beta_factor
if self.layer_types is None:
self.layer_types = [
"full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
__all__ = ["MiniMaxConfig"]

File diff suppressed because it is too large.


@ -0,0 +1,644 @@
# coding=utf-8
# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MiniMax model."""
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeModelOutputWithPast
from ...processing_utils import Unpack
from ...utils import logging
from ..mixtral.configuration_mixtral import MixtralConfig
from ..mixtral.modeling_mixtral import (
MixtralAttention,
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralForQuestionAnswering,
MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralModel,
MixtralPreTrainedModel,
MixtralRMSNorm,
)
logger = logging.get_logger(__name__)
class MiniMaxConfig(MixtralConfig):
r"""
This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate a
MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the MiniMax.
[MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MiniMaxModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 14336):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention
allows sequences of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `4096`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
num_experts_per_tok (`int`, *optional*, defaults to 2):
The number of experts to route per token; can also be interpreted as the `top-k` routing
parameter
num_local_experts (`int`, *optional*, defaults to 8):
Number of experts per Sparse MLP layer.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
router_jitter_noise (`float`, *optional*, defaults to 0.0):
Amount of noise to add to the router.
layer_types (`list`, *optional*):
Attention pattern for each layer.
block_size (`int`, *optional*, defaults to 256):
The length of each attention block, determining how queries, keys, and values
are grouped and processed for intra- and inter-block attention.
full_attn_alpha_factor (`float`, *optional*, defaults to 1):
Weight for residual value in residual connection after normal attention.
full_attn_beta_factor (`float`, *optional*, defaults to 1):
Weight for hidden state value in residual connection after normal attention.
linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
Weight for residual value in residual connection after lightning attention.
linear_attn_beta_factor (`float`, *optional*, defaults to 1):
Weight for hidden state value in residual connection after lightning attention.
mlp_alpha_factor (`float`, *optional*, defaults to 1):
Weight for residual value in residual connection after MLP.
mlp_beta_factor (`float`, *optional*, defaults to 1):
Weight for hidden state value in residual connection after MLP.
```python
>>> from transformers import MiniMaxModel, MiniMaxConfig
>>> # Initializing a MiniMax style configuration
>>> configuration = MiniMaxConfig()
>>> # Initializing a model from the MiniMax style configuration
>>> model = MiniMaxModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
def __init__(
self,
layer_types=None,
block_size=256,
full_attn_alpha_factor=1,
full_attn_beta_factor=1,
linear_attn_alpha_factor=1,
linear_attn_beta_factor=1,
mlp_alpha_factor=1,
mlp_beta_factor=1,
**super_kwargs,
):
super().__init__(**super_kwargs)
self.layer_types = layer_types
self.block_size = block_size
self.full_attn_alpha_factor = full_attn_alpha_factor
self.full_attn_beta_factor = full_attn_beta_factor
self.linear_attn_alpha_factor = linear_attn_alpha_factor
self.linear_attn_beta_factor = linear_attn_beta_factor
self.mlp_alpha_factor = mlp_alpha_factor
self.mlp_beta_factor = mlp_beta_factor
if self.layer_types is None:
self.layer_types = [
"full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
class MiniMaxRMSNorm(MixtralRMSNorm):
pass
class MiniMaxCache(DynamicCache):
def __init__(self):
super().__init__()
self.linear_cache: List[torch.Tensor] = []
def set_linear_cache(self, layer_idx, linear_cache):
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.linear_cache), layer_idx + 1):
self.linear_cache.append([])
self.linear_cache[layer_idx] = linear_cache
def get_linear_cache(self, layer_idx: int):
if layer_idx < len(self):
return self.linear_cache[layer_idx]
return None
def __len__(self):
return max(super().__len__(), len(self.linear_cache))
def __getitem__(self, layer_idx: int):
if layer_idx < len(self.linear_cache) and self.linear_cache[layer_idx] != []:
return (self.linear_cache[layer_idx],)
return super().__getitem__(layer_idx)
def __iter__(self):
for layer_idx in range(len(self)):
yield self[layer_idx]
def batch_repeat_interleave(self, repeats: int):
for layer_idx in range(len(self)):
if self.linear_cache[layer_idx] != []:
self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
else:
self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
def batch_select_indices(self, indices: torch.Tensor):
for layer_idx in range(len(self)):
if self.linear_cache[layer_idx] != []:
self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
else:
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
def crop(self, max_length: int):
raise RuntimeError("MiniMaxCache doesnot support `crop` method")
class MiniMaxLightningAttention(nn.Module):
def __init__(self, config: MiniMaxConfig, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
self.num_attention_heads = config.num_attention_heads
self.num_hidden_layers = config.num_hidden_layers
self.block_size = config.block_size
self.act_fn = ACT2FN[config.hidden_act]
self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads)
self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False)
self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
slope_rate = self.get_slope_rate()
query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)
self.register_buffer("slope_rate", slope_rate)
self.register_buffer("query_decay", query_decay)
self.register_buffer("key_decay", key_decay)
self.register_buffer("diagonal_decay", diagonal_decay)
def get_slope_rate(self):
base = 1 / (2 ** (8 / self.num_attention_heads))
exponent = torch.arange(self.num_attention_heads) + 1
factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5
rate = base**exponent
rate = rate * factor
rate = rate[:, None, None]
return rate
def decay_factors(self, slope_rate):
block_size_range = torch.arange(self.block_size) + 1
query_decay = torch.exp(-slope_rate * block_size_range[:, None])
key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None]))
diagonal_decay = block_size_range[:, None] - block_size_range[None, :]
diagonal_decay = diagonal_decay[None, None, :, :]
diagonal_decay = slope_rate * diagonal_decay
diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf"))
diagonal_decay = torch.exp(diagonal_decay)
return query_decay, key_decay, diagonal_decay
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
batch_size, seq_len, hidden_size = hidden_states.shape
num_blocks = (seq_len + self.block_size - 1) // self.block_size
qkv_states = self.act_fn(self.qkv_proj(hidden_states))
qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim)
query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=3)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# calculate (K.T @ V) and save it as cache
attn_weights_inter = None
if past_key_value is not None:
attn_weights_inter = past_key_value.get_linear_cache(self.layer_idx)
if attn_weights_inter is None:
attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to(
value_states
)
# apply attention_mask
if attention_mask is not None:
attention_mask = attention_mask.to(dtype=torch.bool) # Ensure it's a boolean tensor
value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0)
attn_output = []
for i in range(num_blocks):
start_idx = i * self.block_size
end_idx = min(start_idx + self.block_size, seq_len)
current_block_size = end_idx - start_idx
current_query_states = query_states[:, :, start_idx:end_idx]
current_key_states = key_states[:, :, start_idx:end_idx]
current_value_states = value_states[:, :, start_idx:end_idx]
current_query_decay = self.query_decay[:, :current_block_size]
current_key_decay = self.key_decay[:, -current_block_size:]
current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size]
block_decay = torch.exp(-self.slope_rate * current_block_size)
# intra: ( Q @ K.T ) @ V -> QK * V
attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2))
attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states)
# inter: Q @ ( K.T @ V ) -> Q * KV
attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter)
# final attention output
current_attn_output = attn_output_inter + attn_output_intra
attn_output.append(current_attn_output)
# calculate attn_weights_inter for next block or cache
next_attn_weights_inter = torch.matmul(
(current_key_states * current_key_decay).transpose(-1, -2), current_value_states
)
attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter
else:
ratio = torch.exp(-self.slope_rate)
attn_output = []
for i in range(seq_len):
current_query_states = query_states[:, :, i : i + 1]
current_key_states = key_states[:, :, i : i + 1]
current_value_states = value_states[:, :, i : i + 1]
current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states)
attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter
current_attn_output = torch.matmul(current_query_states, attn_weights_inter)
attn_output.append(current_attn_output)
# concatenate attention outputs over all blocks
attn_output = torch.cat(attn_output, dim=-2)
# final output projection
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim)
attn_output = self.norm(attn_output)
attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output
attn_output = self.out_proj(attn_output)
# update cache
if past_key_value is not None:
past_key_value.set_linear_cache(self.layer_idx, attn_weights_inter)
return attn_output, attn_weights_inter
class MiniMaxAttention(MixtralAttention):
pass
class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):
def __init__(self, config: MiniMaxConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.layer_idx = layer_idx
self.layer_type = config.layer_types[layer_idx]
self.mlp_alpha_factor = config.mlp_alpha_factor
self.mlp_beta_factor = config.mlp_beta_factor
if self.layer_type == "linear_attention":
self.self_attn = MiniMaxLightningAttention(config, layer_idx)
self.attn_alpha_factor = config.linear_attn_alpha_factor
self.attn_beta_factor = config.linear_attn_beta_factor
else:
self.self_attn = MiniMaxAttention(config, layer_idx)
self.attn_alpha_factor = config.full_attn_alpha_factor
self.attn_beta_factor = config.full_attn_beta_factor
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
attention_mask (`torch.Tensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
should not be returned during inference.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
hidden_states = self.input_layernorm(hidden_states)
residual = hidden_states
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor
# Fully Connected
hidden_states = self.post_attention_layernorm(hidden_states)
residual = hidden_states
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if output_router_logits:
outputs += (router_logits,)
return outputs
class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
_supports_cache_class = True # Note: only supports MiniMaxCache
_supports_static_cache = False
_supports_quantized_cache = False
class MiniMaxModel(MixtralModel):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> MoeModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache and past_key_values is None:
past_key_values = MiniMaxCache()
elif use_cache and not isinstance(past_key_values, MiniMaxCache):
raise ValueError(
f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
causal_mask = mask_function(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if decoder_layer.layer_type == "full_attention":
input_attention_mask = causal_mask
else:
# lightning attention uses original attention_mask, and uses it only for the first step
input_attention_mask = attention_mask
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=input_attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
class MiniMaxForCausalLM(MixtralForCausalLM):
def forward(self, **super_kwargs):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> from transformers import AutoTokenizer, MiniMaxForCausalLM
>>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
return super().forward(**super_kwargs)
class MiniMaxForSequenceClassification(MixtralForSequenceClassification):
pass
class MiniMaxForTokenClassification(MixtralForTokenClassification):
pass
class MiniMaxForQuestionAnswering(MixtralForQuestionAnswering):
pass
__all__ = [
"MiniMaxConfig",
"MiniMaxPreTrainedModel",
"MiniMaxModel",
"MiniMaxForCausalLM",
"MiniMaxForSequenceClassification",
"MiniMaxForTokenClassification",
"MiniMaxForQuestionAnswering",
]
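To make the lightning-attention recurrence above concrete, here is a standalone, single-head sketch of the per-token update used when decoding with the cache (no input activation, output gating, or normalization; names and shapes are simplified assumptions, not the module's API). It checks that the recurrent form matches an explicit causal attention with exponential decay:

```python
import torch

torch.manual_seed(0)
seq_len, head_dim = 5, 4
slope = 0.1  # per-head decay slope, cf. get_slope_rate()
q = torch.randn(seq_len, head_dim)
k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)

# Recurrent form: keep a running (head_dim x head_dim) K^T V state with exponential decay
ratio = torch.exp(torch.tensor(-slope))
kv_state = torch.zeros(head_dim, head_dim)
recurrent_out = []
for t in range(seq_len):
    kv_state = ratio * kv_state + k[t].unsqueeze(-1) @ v[t].unsqueeze(0)
    recurrent_out.append(q[t] @ kv_state)
recurrent_out = torch.stack(recurrent_out)

# Equivalent quadratic form: causal Q K^T weights decayed by exp(-slope * (t - j))
t_idx = torch.arange(seq_len)
decay = torch.exp(-slope * (t_idx[:, None] - t_idx[None, :]).clamp(min=0).float())
decay = decay * (t_idx[:, None] >= t_idx[None, :])  # causal mask
quadratic_out = (q @ k.T * decay) @ v

torch.testing.assert_close(recurrent_out, quadratic_out)
```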


@ -0,0 +1,279 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch MiniMax model."""
import unittest
import pytest
from transformers import MiniMaxConfig, is_torch_available
from transformers.cache_utils import Cache
from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_accelerator,
require_torch_gpu,
slow,
torch_device,
)
if is_torch_available():
import torch
from transformers import (
MiniMaxForCausalLM,
MiniMaxForQuestionAnswering,
MiniMaxForSequenceClassification,
MiniMaxForTokenClassification,
MiniMaxModel,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
class MiniMaxModelTester(CausalLMModelTester):
config_class = MiniMaxConfig
if is_torch_available():
base_model_class = MiniMaxModel
causal_lm_class = MiniMaxForCausalLM
sequence_class = MiniMaxForSequenceClassification
token_class = MiniMaxForTokenClassification
question_answering_class = MiniMaxForQuestionAnswering
def __init__(self, parent, layer_types=None, block_size=3):
super().__init__(parent)
self.layer_types = layer_types
self.block_size = block_size
@require_torch
class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase):
all_model_classes = (
(
MiniMaxModel,
MiniMaxForCausalLM,
MiniMaxForSequenceClassification,
MiniMaxForTokenClassification,
MiniMaxForQuestionAnswering,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": MiniMaxModel,
"text-classification": MiniMaxForSequenceClassification,
"token-classification": MiniMaxForTokenClassification,
"text-generation": MiniMaxForCausalLM,
"question-answering": MiniMaxForQuestionAnswering,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
model_tester_class = MiniMaxModelTester
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="MiniMax flash attention does not support right padding")
def test_load_balancing_loss(self):
r"""
Let's make sure we can actually compute the loss and do a backward on it.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.num_local_experts = 8
config.output_router_logits = True
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
model = MiniMaxForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask)
self.assertEqual(result.router_logits[0].shape, (91, config.num_local_experts))
torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)
# First, we make sure that adding padding tokens doesn't change the loss
# loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding)
pad_length = 1000
# Add padding tokens (assume that pad_token_id=1) to input_ids
padding_block = torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(torch_device)
padded_input_ids = torch.cat((padding_block, input_ids), dim=1) # this is to simulate padding to the left
padded_attention_mask = padded_input_ids.ne(1).to(torch_device)
padded_result = model(padded_input_ids, attention_mask=padded_attention_mask)
torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4)
# We make sure that the loss of including padding tokens != the loss without padding tokens
# if attention_mask=None --> we don't exclude padding tokens
include_padding_result = model(padded_input_ids, attention_mask=None)
# This is to mimic torch.testing.assert_not_close
self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item())
def _check_attentions_for_generate(
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (output_length - prompt_length))
use_cache = decoder_past_key_values is not None
for generated_length, iter_attentions in enumerate(attentions):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
expected_shape = (
batch_size,
config.num_attention_heads,
model_input_length,
prompt_length + generated_length,
)
for layer_idx, layer_attention in enumerate(iter_attentions):
if config.layer_types[layer_idx] == "full_attention":
self.assertEqual(layer_attention.shape, expected_shape)
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
# (batch, head, seq_length, head_features)
key_value_cache_expected_shape = (
batch_size,
config.num_key_value_heads,
cache_length,
config.hidden_size // config.num_attention_heads,
)
# (batch, head, head_features, head_features)
linear_cache_expected_shape = (
batch_size,
config.num_attention_heads,
config.hidden_size // config.num_attention_heads,
config.hidden_size // config.num_attention_heads,
)
for layer_idx in range(config.num_hidden_layers):
if config.layer_types[layer_idx] == "full_attention":
self.assertEqual(decoder_past_key_values[layer_idx][0].shape, key_value_cache_expected_shape)
self.assertEqual(decoder_past_key_values[layer_idx][1].shape, key_value_cache_expected_shape)
else:
self.assertEqual(decoder_past_key_values[layer_idx][0].shape, linear_cache_expected_shape)
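# Illustrative sketch only (not part of the diff): the two cache layouts checked above,
# computed for a made-up tiny config. Note that the linear (lightning) attention state has
# a fixed (head_dim, head_dim) shape and does not grow with cache_length.
# batch_size, num_attention_heads, num_key_value_heads = 2, 4, 2
# hidden_size, cache_length = 32, 7
# head_dim = hidden_size // num_attention_heads                                  # 8
# key_value_cache_shape = (batch_size, num_key_value_heads, cache_length, head_dim)  # (2, 2, 7, 8) for full_attention layers
# linear_cache_shape = (batch_size, num_attention_heads, head_dim, head_dim)         # (2, 4, 8, 8) for lightning attention layers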
@pytest.mark.generate
def test_past_key_values_format(self, custom_all_cache_shapes=None):
"""
Test that the KV cache is formatted correctly.
"""
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config).to(torch_device)
model = model.eval()
if "use_cache" not in inputs:
inputs["use_cache"] = True
outputs = model(**inputs)
past_kv = outputs["past_key_values"]
batch_size, seq_length = inputs["input_ids"].shape
self._check_past_key_values_for_generate(batch_size, past_kv, seq_length, config)
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
def test_assisted_decoding_sample(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
def test_assisted_decoding_matches_greedy_search_0_random(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
def test_assisted_decoding_matches_greedy_search_1_same(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@require_torch
@require_torch_accelerator
@slow
class MiniMaxIntegrationTest(unittest.TestCase):
def test_small_model_logits(self):
model_id = "geetu040/MiniMax-tiny"
dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device)
model = MiniMaxForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(
torch_device
)
expected_slice = torch.tensor(
[[1.0312, -0.5156, -0.3262], [-0.1152, 0.4336, 0.2412], [1.2188, -0.5898, -0.0381]]
).to(torch_device)
with torch.no_grad():
logits = model(dummy_input).logits
logits = logits.float()
torch.testing.assert_close(logits[0, :3, :3], expected_slice, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(logits[1, :3, :3], expected_slice, atol=1e-3, rtol=1e-3)
def test_small_model_generation(self):
model_id = "geetu040/MiniMax-tiny"
dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device)
model = MiniMaxForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(
torch_device
)
expected_slice = (
torch.tensor([[0, 1, 0, 933, 307, 3102, 2457, 1208], [0, 1, 0, 933, 307, 3102, 2457, 1208]])
.to(torch.int64)
.to(torch_device)
)
outputs = model.generate(dummy_input, max_new_tokens=5, do_sample=False)
torch.testing.assert_close(outputs, expected_slice, atol=1e-3, rtol=1e-3)
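# Note (not part of the diff): the integration tests above are gated behind @slow and an
# accelerator requirement, so they are skipped by default. A typical local invocation would
# look roughly like the following (the test-file path is assumed, not taken from this diff):
#   RUN_SLOW=1 pytest tests/models/minimax/test_modeling_minimax.py -k MiniMaxIntegrationTest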

View File

@ -3946,7 +3946,7 @@ class ModelTesterMixin:
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
- WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen_moe", "starcoder2"]
+ WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "minimax", "qwen2", "qwen_moe", "starcoder2"]
if len(self.all_generative_model_classes) == 0:
self.skipTest(f"No generative model classes for {self.__class__.__name__}")