mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add support for MiniMax's MiniMax-Text-01 (#35831)
* end-to-end architecture
* lightning-attn: refactor, clean, optimize
* put minimax_text_01 in other files
* use latest __init__ standards and auto-generate modular
* support attention_mask for lightning-attn
* Revert "use latest __init__ standards and auto-generate modular"
This reverts commit d8d3c409d8
.
* fix modular conversion
* pass both attention masks instead of tuple
* formatting
* Updated Dynamic Cache
* created MiniMaxText01Cache
* fix hardcoded slope_rate
* update attn_type_list in config
* fix lightning when use_cache=False
* copy tests from mixtral
* (checkpoint) all tests pass for normal attention
* fix all unittests
* fix import sorting
* fix consistency and formatting tests
* fix config
* update tests, since changes in main
* fix seq_len error
* create dummy docs
* fix checkpoint
* add checkpoint in config docstring
* run modular_conversion
* update docs
* fix checkpoint path and update tests
* fix ruff
* remove repeated expected_slice
* update docs
* rename "minimax-text-01" to "minimax"
* inherit config from mixtral
* remove from docs in other languages
* undo files that should be untouched
* move minimax to end in conversation docs
* use MiniMaxForCausalLM as it is
* ruff fixes
* run modular
* fix docstring example in causallm
* refactor attention loop and decay factors
* refactor config in modular
* run modular
* refactor cache
* rename static_cache to linear_cache
* make positional embeddings necessary
* remove unnecessary layernorms declarations
* fix import in tests
* refactor attention in next tokens
* remove outdated code
* formatting and modular
* update tests
* rename layernorm alpha/beta factors
* register decay factors as buffers
* remove unused declarations of decay factors
* update config for alpha/beta factors
* run modular
* remove head_dim in tests
* remove minimax from fx.py
* remove stuff that is not really needed
* update __init__
* update qkv torch.split
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
* fix qkv torch.split
* quality fixes
* remove mistakenly added dummy
* purge unused ModelTester code
* fix-copies
* run fix-copies
* fix head_dim
* write cache formatting tests
* remove postnorm
* avoid contiguous in attention current states
* update expected_slice
* add generation test for integration
* fix dtype in generation test
* update authors
* update with changes in main
* update graident checkpointing and minor fixes
* fix mutable attn_type_list
* rename: attn_type -> layer_type
* update for layer_types
* update integration tests
* update checkpoint
* clean overview in docs
---------
Co-authored-by: Shakib-IO <shakib.khan17@northsouth.edu>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
parent
037acf1d10
commit
55736eea99
@ -555,6 +555,8 @@
|
|||||||
title: MegatronBERT
|
title: MegatronBERT
|
||||||
- local: model_doc/megatron_gpt2
|
- local: model_doc/megatron_gpt2
|
||||||
title: MegatronGPT2
|
title: MegatronGPT2
|
||||||
|
- local: model_doc/minimax
|
||||||
|
title: MiniMax
|
||||||
- local: model_doc/mistral
|
- local: model_doc/mistral
|
||||||
title: Mistral
|
title: Mistral
|
||||||
- local: model_doc/mixtral
|
- local: model_doc/mixtral
|
||||||
|
189
docs/source/en/model_doc/minimax.md
Normal file
189
docs/source/en/model_doc/minimax.md
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
<!--Copyright 2025 MiniMaxAI and The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
# MiniMax
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The DepthPro model was proposed in [MiniMax-01: Scaling Foundation Models with Lightning Attention](https://arxiv.org/abs/2501.08313) by MiniMax, Aonian Li, Bangwei Gong, Bo Yang, Boji Shan, Chang Liu, Cheng Zhu, Chunhao Zhang, Congchao Guo, Da Chen, Dong Li, Enwei Jiao, Gengxin Li, Guojun Zhang, Haohai Sun, Houze Dong, Jiadai Zhu, Jiaqi Zhuang, Jiayuan Song, Jin Zhu, Jingtao Han, Jingyang Li, Junbin Xie, Junhao Xu, Junjie Yan, Kaishun Zhang, Kecheng Xiao, Kexi Kang, Le Han, Leyang Wang, Lianfei Yu, Liheng Feng, Lin Zheng, Linbo Chai, Long Xing, Meizhi Ju, Mingyuan Chi, Mozhi Zhang, Peikai Huang, Pengcheng Niu, Pengfei Li, Pengyu Zhao, Qi Yang, Qidi Xu, Qiexiang Wang, Qin Wang, Qiuhui Li, Ruitao Leng, Shengmin Shi, Shuqi Yu, Sichen Li, Songquan Zhu, Tao Huang, Tianrun Liang, Weigao Sun, Weixuan Sun, Weiyu Cheng, Wenkai Li, Xiangjun Song, Xiao Su, Xiaodong Han, Xinjie Zhang, Xinzhu Hou, Xu Min, Xun Zou, Xuyang Shen, Yan Gong, Yingjie Zhu, Yipeng Zhou, Yiran Zhong, Yongyi Hu, Yuanxiang Fan, Yue Yu, Yufeng Yang, Yuhao Li, Yunan Huang, Yunji Li, Yunpeng Huang, Yunzhi Xu, Yuxin Mao, Zehan Li, Zekang Li, Zewei Tao, Zewen Ying, Zhaoyang Cong, Zhen Qin, Zhenhua Fan, Zhihang Yu, Zhuo Jiang, Zijia Wu.
|
||||||
|
|
||||||
|
The abstract from the paper is the following:
|
||||||
|
|
||||||
|
*We introduce MiniMax-01 series, including MiniMax-Text-01 and MiniMax-VL-01, which are comparable to top-tier models while offering superior capabilities in processing longer contexts. The core lies in lightning attention and its efficient scaling. To maximize computational capacity, we integrate it with Mixture of Experts (MoE), creating a model with 32 experts and 456 billion total parameters, of which 45.9 billion are activated for each token. We develop an optimized parallel strategy and highly efficient computation-communication overlap techniques for MoE and lightning attention. This approach enables us to conduct efficient training and inference on models with hundreds of billions of parameters across contexts spanning millions of tokens. The context window of MiniMax-Text-01 can reach up to 1 million tokens during training and extrapolate to 4 million tokens during inference at an affordable cost. Our vision-language model, MiniMax-VL-01 is built through continued training with 512 billion vision-language tokens. Experiments on both standard and in-house benchmarks show that our models match the performance of state-of-the-art models like GPT-4o and Claude-3.5-Sonnet while offering 20-32 times longer context window.*
|
||||||
|
|
||||||
|
### Architectural details
|
||||||
|
|
||||||
|
MiniMax is a powerful language model with 456 billion total parameters, of which 45.9 billion are activated per token. To better unlock the long context capabilities of the model, MiniMax adopts a hybrid architecture that combines Lightning Attention, Softmax Attention and Mixture-of-Experts (MoE). Leveraging advanced parallel strategies and innovative compute-communication overlap methods—such as Linear Attention Sequence Parallelism Plus (LASP+), varlen ring attention, Expert Tensor Parallel (ETP), etc., MiniMax's training context length is extended to 1 million tokens, and it can handle a context of up to 4 million tokens during the inference. On various academic benchmarks, MiniMax also demonstrates the performance of a top-tier model.
|
||||||
|
|
||||||
|
The architecture of MiniMax is briefly described as follows:
|
||||||
|
|
||||||
|
- Total Parameters: 456B
|
||||||
|
- Activated Parameters per Token: 45.9B
|
||||||
|
- Number Layers: 80
|
||||||
|
- Hybrid Attention: a softmax attention is positioned after every 7 lightning attention.
|
||||||
|
- Number of attention heads: 64
|
||||||
|
- Attention head dimension: 128
|
||||||
|
- Mixture of Experts:
|
||||||
|
- Number of experts: 32
|
||||||
|
- Expert hidden dimension: 9216
|
||||||
|
- Top-2 routing strategy
|
||||||
|
- Positional Encoding: Rotary Position Embedding (RoPE) applied to half of the attention head dimension with a base frequency of 10,000,000
|
||||||
|
- Hidden Size: 6144
|
||||||
|
- Vocab Size: 200,064
|
||||||
|
|
||||||
|
For more details refer to the [release blog post](https://www.minimaxi.com/en/news/minimax-01-series-2).
|
||||||
|
|
||||||
|
### License
|
||||||
|
|
||||||
|
`MiniMax` is released under the MINIMAX MODEL LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
## Usage tips
|
||||||
|
|
||||||
|
The pre-trained model can be used as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", device_map="auto")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
... {"role": "user", "content": "What is your favourite condiment?"},
|
||||||
|
... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
|
||||||
|
... {"role": "user", "content": "Do you have mayonnaise recipes?"}
|
||||||
|
... ]
|
||||||
|
|
||||||
|
>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
|
||||||
|
|
||||||
|
>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
|
||||||
|
>>> tokenizer.batch_decode(generated_ids)[0]
|
||||||
|
"Mayonnaise can be made as follows: (...)"
|
||||||
|
```
|
||||||
|
|
||||||
|
As can be seen, the instruction-tuned model requires a [chat template](../chat_templating) to be applied to make sure the inputs are prepared in the right format.
|
||||||
|
|
||||||
|
## Speeding up MiniMax by using Flash Attention
|
||||||
|
|
||||||
|
The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
|
||||||
|
|
||||||
|
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -U flash-attn --no-build-isolation
|
||||||
|
```
|
||||||
|
|
||||||
|
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). Make also sure to load your model in half-precision (e.g. `torch.float16`)
|
||||||
|
|
||||||
|
To load and run a model using Flash Attention-2, refer to the snippet below:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
|
||||||
|
|
||||||
|
>>> prompt = "My favourite condiment is"
|
||||||
|
|
||||||
|
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
|
||||||
|
>>> model.to(device)
|
||||||
|
|
||||||
|
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
|
||||||
|
>>> tokenizer.batch_decode(generated_ids)[0]
|
||||||
|
"The expected output"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Sliding window Attention
|
||||||
|
|
||||||
|
The current implementation supports the sliding window attention mechanism and memory efficient cache management.
|
||||||
|
To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`).
|
||||||
|
|
||||||
|
The Flash Attention-2 model uses also a more memory efficient cache slicing mechanism - as recommended per the official implementation of Mistral model that use rolling cache mechanism we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"` and use the absolute position of the current token to compute the positional embedding.
|
||||||
|
|
||||||
|
## Shrinking down MiniMax using quantization
|
||||||
|
|
||||||
|
As the MiniMax model has 456 billion parameters, that would require about 912GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter), about 228 GB of RAM is required.
|
||||||
|
|
||||||
|
Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization.md) for alternative quantization methods):
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||||
|
|
||||||
|
>>> # specify how to quantize the model
|
||||||
|
>>> quantization_config = BitsAndBytesConfig(
|
||||||
|
... load_in_4bit=True,
|
||||||
|
... bnb_4bit_quant_type="nf4",
|
||||||
|
... bnb_4bit_compute_dtype="torch.float16",
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", quantization_config=True, device_map="auto")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
|
||||||
|
|
||||||
|
>>> prompt = "My favourite condiment is"
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
... {"role": "user", "content": "What is your favourite condiment?"},
|
||||||
|
... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
|
||||||
|
... {"role": "user", "content": "Do you have mayonnaise recipes?"}
|
||||||
|
... ]
|
||||||
|
|
||||||
|
>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
|
||||||
|
|
||||||
|
>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
|
||||||
|
>>> tokenizer.batch_decode(generated_ids)[0]
|
||||||
|
"The expected output"
|
||||||
|
```
|
||||||
|
|
||||||
|
This model was contributed by [geetu040](https://github.com/geetu040).
|
||||||
|
The original code can be found [here](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf/blob/main/modeling_minimax.py).
|
||||||
|
|
||||||
|
## Resources
|
||||||
|
|
||||||
|
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with MiniMax. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||||
|
|
||||||
|
<PipelineTag pipeline="text-generation"/>
|
||||||
|
|
||||||
|
- The [Alignment Handbook](https://github.com/huggingface/alignment-handbook) by Hugging Face includes scripts and recipes to perform supervised fine-tuning (SFT) and direct preference optimization with Mistral-7B. This includes scripts for full fine-tuning, QLoRa on a single GPU as well as multi-GPU fine-tuning.
|
||||||
|
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||||
|
|
||||||
|
## MiniMaxConfig
|
||||||
|
|
||||||
|
[[autodoc]] MiniMaxConfig
|
||||||
|
|
||||||
|
## MiniMaxModel
|
||||||
|
|
||||||
|
[[autodoc]] MiniMaxModel
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## MiniMaxForCausalLM
|
||||||
|
|
||||||
|
[[autodoc]] MiniMaxForCausalLM
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## MiniMaxForSequenceClassification
|
||||||
|
|
||||||
|
[[autodoc]] MiniMaxForSequenceClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## MiniMaxForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] MiniMaxForTokenClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## MiniMaxForQuestionAnswering
|
||||||
|
[[autodoc]] MiniMaxForQuestionAnswering
|
||||||
|
- forward
|
@ -1218,6 +1218,7 @@ ALLOWED_LAYER_TYPES = (
|
|||||||
"full_attention",
|
"full_attention",
|
||||||
"sliding_attention",
|
"sliding_attention",
|
||||||
"chunked_attention",
|
"chunked_attention",
|
||||||
|
"linear_attention", # used in minimax
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1976,6 +1976,7 @@ class GenerationMixin(ContinuousMixin):
|
|||||||
and "jamba" not in self.__class__.__name__.lower()
|
and "jamba" not in self.__class__.__name__.lower()
|
||||||
and "zamba" not in self.__class__.__name__.lower()
|
and "zamba" not in self.__class__.__name__.lower()
|
||||||
and "bamba" not in self.__class__.__name__.lower()
|
and "bamba" not in self.__class__.__name__.lower()
|
||||||
|
and "minimax" not in self.__class__.__name__.lower()
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_cache_for_generation(
|
def _prepare_cache_for_generation(
|
||||||
|
@ -185,6 +185,7 @@ if TYPE_CHECKING:
|
|||||||
from .megatron_gpt2 import *
|
from .megatron_gpt2 import *
|
||||||
from .mgp_str import *
|
from .mgp_str import *
|
||||||
from .mimi import *
|
from .mimi import *
|
||||||
|
from .minimax import *
|
||||||
from .mistral import *
|
from .mistral import *
|
||||||
from .mistral3 import *
|
from .mistral3 import *
|
||||||
from .mixtral import *
|
from .mixtral import *
|
||||||
|
@ -211,6 +211,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
|||||||
("megatron-bert", "MegatronBertConfig"),
|
("megatron-bert", "MegatronBertConfig"),
|
||||||
("mgp-str", "MgpstrConfig"),
|
("mgp-str", "MgpstrConfig"),
|
||||||
("mimi", "MimiConfig"),
|
("mimi", "MimiConfig"),
|
||||||
|
("minimax", "MiniMaxConfig"),
|
||||||
("mistral", "MistralConfig"),
|
("mistral", "MistralConfig"),
|
||||||
("mistral3", "Mistral3Config"),
|
("mistral3", "Mistral3Config"),
|
||||||
("mixtral", "MixtralConfig"),
|
("mixtral", "MixtralConfig"),
|
||||||
@ -586,6 +587,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
|||||||
("megatron_gpt2", "Megatron-GPT2"),
|
("megatron_gpt2", "Megatron-GPT2"),
|
||||||
("mgp-str", "MGP-STR"),
|
("mgp-str", "MGP-STR"),
|
||||||
("mimi", "Mimi"),
|
("mimi", "Mimi"),
|
||||||
|
("minimax", "MiniMax"),
|
||||||
("mistral", "Mistral"),
|
("mistral", "Mistral"),
|
||||||
("mistral3", "Mistral3"),
|
("mistral3", "Mistral3"),
|
||||||
("mixtral", "Mixtral"),
|
("mixtral", "Mixtral"),
|
||||||
|
@ -201,6 +201,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("megatron-bert", "MegatronBertModel"),
|
("megatron-bert", "MegatronBertModel"),
|
||||||
("mgp-str", "MgpstrForSceneTextRecognition"),
|
("mgp-str", "MgpstrForSceneTextRecognition"),
|
||||||
("mimi", "MimiModel"),
|
("mimi", "MimiModel"),
|
||||||
|
("minimax", "MiniMaxModel"),
|
||||||
("mistral", "MistralModel"),
|
("mistral", "MistralModel"),
|
||||||
("mistral3", "Mistral3Model"),
|
("mistral3", "Mistral3Model"),
|
||||||
("mixtral", "MixtralModel"),
|
("mixtral", "MixtralModel"),
|
||||||
@ -594,6 +595,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("mbart", "MBartForCausalLM"),
|
("mbart", "MBartForCausalLM"),
|
||||||
("mega", "MegaForCausalLM"),
|
("mega", "MegaForCausalLM"),
|
||||||
("megatron-bert", "MegatronBertForCausalLM"),
|
("megatron-bert", "MegatronBertForCausalLM"),
|
||||||
|
("minimax", "MiniMaxForCausalLM"),
|
||||||
("mistral", "MistralForCausalLM"),
|
("mistral", "MistralForCausalLM"),
|
||||||
("mixtral", "MixtralForCausalLM"),
|
("mixtral", "MixtralForCausalLM"),
|
||||||
("mllama", "MllamaForCausalLM"),
|
("mllama", "MllamaForCausalLM"),
|
||||||
@ -1106,6 +1108,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("mbart", "MBartForSequenceClassification"),
|
("mbart", "MBartForSequenceClassification"),
|
||||||
("mega", "MegaForSequenceClassification"),
|
("mega", "MegaForSequenceClassification"),
|
||||||
("megatron-bert", "MegatronBertForSequenceClassification"),
|
("megatron-bert", "MegatronBertForSequenceClassification"),
|
||||||
|
("minimax", "MiniMaxForSequenceClassification"),
|
||||||
("mistral", "MistralForSequenceClassification"),
|
("mistral", "MistralForSequenceClassification"),
|
||||||
("mixtral", "MixtralForSequenceClassification"),
|
("mixtral", "MixtralForSequenceClassification"),
|
||||||
("mobilebert", "MobileBertForSequenceClassification"),
|
("mobilebert", "MobileBertForSequenceClassification"),
|
||||||
@ -1197,6 +1200,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
|||||||
("mbart", "MBartForQuestionAnswering"),
|
("mbart", "MBartForQuestionAnswering"),
|
||||||
("mega", "MegaForQuestionAnswering"),
|
("mega", "MegaForQuestionAnswering"),
|
||||||
("megatron-bert", "MegatronBertForQuestionAnswering"),
|
("megatron-bert", "MegatronBertForQuestionAnswering"),
|
||||||
|
("minimax", "MiniMaxForQuestionAnswering"),
|
||||||
("mistral", "MistralForQuestionAnswering"),
|
("mistral", "MistralForQuestionAnswering"),
|
||||||
("mixtral", "MixtralForQuestionAnswering"),
|
("mixtral", "MixtralForQuestionAnswering"),
|
||||||
("mobilebert", "MobileBertForQuestionAnswering"),
|
("mobilebert", "MobileBertForQuestionAnswering"),
|
||||||
@ -1303,6 +1307,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("markuplm", "MarkupLMForTokenClassification"),
|
("markuplm", "MarkupLMForTokenClassification"),
|
||||||
("mega", "MegaForTokenClassification"),
|
("mega", "MegaForTokenClassification"),
|
||||||
("megatron-bert", "MegatronBertForTokenClassification"),
|
("megatron-bert", "MegatronBertForTokenClassification"),
|
||||||
|
("minimax", "MiniMaxForTokenClassification"),
|
||||||
("mistral", "MistralForTokenClassification"),
|
("mistral", "MistralForTokenClassification"),
|
||||||
("mixtral", "MixtralForTokenClassification"),
|
("mixtral", "MixtralForTokenClassification"),
|
||||||
("mobilebert", "MobileBertForTokenClassification"),
|
("mobilebert", "MobileBertForTokenClassification"),
|
||||||
|
@ -342,6 +342,13 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
|||||||
("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
|
("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("mgp-str", ("MgpstrTokenizer", None)),
|
("mgp-str", ("MgpstrTokenizer", None)),
|
||||||
|
(
|
||||||
|
"minimax",
|
||||||
|
(
|
||||||
|
"GPT2Tokenizer" if is_sentencepiece_available() else None,
|
||||||
|
"GPT2TokenizerFast" if is_tokenizers_available() else None,
|
||||||
|
),
|
||||||
|
),
|
||||||
(
|
(
|
||||||
"mistral",
|
"mistral",
|
||||||
(
|
(
|
||||||
|
29
src/transformers/models/minimax/__init__.py
Normal file
29
src/transformers/models/minimax/__init__.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_minimax import *
|
||||||
|
from .modeling_minimax import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
230
src/transformers/models/minimax/configuration_minimax.py
Normal file
230
src/transformers/models/minimax/configuration_minimax.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/minimax/modular_minimax.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_minimax.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate an
|
||||||
|
MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||||
|
with the defaults will yield a similar configuration to that of the MiniMax.
|
||||||
|
|
||||||
|
[MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf)
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 32000):
|
||||||
|
Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`MiniMaxModel`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 4096):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 14336):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||||
|
Number of hidden layers in the Transformer encoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
||||||
|
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
|
||||||
|
The attention head dimension.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
||||||
|
The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention
|
||||||
|
allows sequence of up to 4096*32 tokens.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
pad_token_id (`int`, *optional*):
|
||||||
|
The id of the padding token.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
The id of the "beginning-of-sequence" token.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
The id of the "end-of-sequence" token.
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether the model's input and output word embeddings should be tied.
|
||||||
|
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
sliding_window (`int`, *optional*):
|
||||||
|
Sliding window attention window size. If not specified, will default to `4096`.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
num_experts_per_tok (`int`, *optional*, defaults to 2):
|
||||||
|
The number of experts to route per-token, can be also interpreted as the `top-k` routing
|
||||||
|
parameter
|
||||||
|
num_local_experts (`int`, *optional*, defaults to 8):
|
||||||
|
Number of experts per Sparse MLP layer.
|
||||||
|
output_router_logits (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not the router logits should be returned by the model. Enabeling this will also
|
||||||
|
allow the model to output the auxiliary loss. See [here]() for more details
|
||||||
|
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
||||||
|
The aux loss factor for the total loss.
|
||||||
|
router_jitter_noise (`float`, *optional*, defaults to 0.0):
|
||||||
|
Amount of noise to add to the router.
|
||||||
|
layer_types (`list`, *optional*):
|
||||||
|
Attention pattern for each layer.
|
||||||
|
block_size (`int`, *optional*, defaults to 256):
|
||||||
|
The length of each attention block, determining how queries, keys, and values
|
||||||
|
are grouped and processed for intra- and inter-block attention.
|
||||||
|
full_attn_alpha_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for residual value in residual connection after normal attention.
|
||||||
|
full_attn_beta_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for hidden state value in residual connection after normal attention.
|
||||||
|
linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for residual value in residual connection after lightning attention.
|
||||||
|
linear_attn_beta_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for hidden state value in residual connection after lightning attention.
|
||||||
|
mlp_alpha_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for residual value in residual connection after MLP.
|
||||||
|
mlp_beta_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for hidden state value in residual connection after MLP.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import MiniMaxModel, MiniMaxConfig
|
||||||
|
|
||||||
|
>>> # Initializing a MiniMax style configuration
|
||||||
|
>>> configuration = MiniMaxConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model from the MiniMax style configuration
|
||||||
|
>>> model = MiniMaxModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "minimax"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
base_model_tp_plan = {
|
||||||
|
"layers.*.self_attn.q_proj": "colwise",
|
||||||
|
"layers.*.self_attn.k_proj": "colwise",
|
||||||
|
"layers.*.self_attn.v_proj": "colwise",
|
||||||
|
"layers.*.self_attn.o_proj": "rowwise",
|
||||||
|
"layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
|
||||||
|
"layers.*.block_sparse_moe.experts.*.w1": "colwise",
|
||||||
|
"layers.*.block_sparse_moe.experts.*.w2": "rowwise",
|
||||||
|
"layers.*.block_sparse_moe.experts.*.w3": "colwise",
|
||||||
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=32000,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=14336,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
head_dim=None,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=4096 * 32,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=1e6,
|
||||||
|
sliding_window=None,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
num_local_experts=8,
|
||||||
|
output_router_logits=False,
|
||||||
|
router_aux_loss_coef=0.001,
|
||||||
|
router_jitter_noise=0.0,
|
||||||
|
layer_types=None,
|
||||||
|
block_size=256,
|
||||||
|
full_attn_alpha_factor=1,
|
||||||
|
full_attn_beta_factor=1,
|
||||||
|
linear_attn_alpha_factor=1,
|
||||||
|
linear_attn_beta_factor=1,
|
||||||
|
mlp_alpha_factor=1,
|
||||||
|
mlp_beta_factor=1,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
|
self.num_local_experts = num_local_experts
|
||||||
|
self.output_router_logits = output_router_logits
|
||||||
|
self.router_aux_loss_coef = router_aux_loss_coef
|
||||||
|
self.router_jitter_noise = router_jitter_noise
|
||||||
|
self.layer_types = layer_types
|
||||||
|
self.block_size = block_size
|
||||||
|
self.full_attn_alpha_factor = full_attn_alpha_factor
|
||||||
|
self.full_attn_beta_factor = full_attn_beta_factor
|
||||||
|
self.linear_attn_alpha_factor = linear_attn_alpha_factor
|
||||||
|
self.linear_attn_beta_factor = linear_attn_beta_factor
|
||||||
|
self.mlp_alpha_factor = mlp_alpha_factor
|
||||||
|
self.mlp_beta_factor = mlp_beta_factor
|
||||||
|
|
||||||
|
if self.layer_types is None:
|
||||||
|
self.layer_types = [
|
||||||
|
"full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers)
|
||||||
|
]
|
||||||
|
layer_type_validation(self.layer_types)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["MiniMaxConfig"]
|
1259
src/transformers/models/minimax/modeling_minimax.py
Normal file
1259
src/transformers/models/minimax/modeling_minimax.py
Normal file
File diff suppressed because it is too large
Load Diff
644
src/transformers/models/minimax/modular_minimax.py
Normal file
644
src/transformers/models/minimax/modular_minimax.py
Normal file
@ -0,0 +1,644 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch MiniMax model."""
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache, DynamicCache
|
||||||
|
from ...configuration_utils import layer_type_validation
|
||||||
|
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||||
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
|
from ...modeling_outputs import MoeModelOutputWithPast
|
||||||
|
from ...processing_utils import Unpack
|
||||||
|
from ...utils import logging
|
||||||
|
from ..mixtral.configuration_mixtral import MixtralConfig
|
||||||
|
from ..mixtral.modeling_mixtral import (
|
||||||
|
MixtralAttention,
|
||||||
|
MixtralDecoderLayer,
|
||||||
|
MixtralForCausalLM,
|
||||||
|
MixtralForQuestionAnswering,
|
||||||
|
MixtralForSequenceClassification,
|
||||||
|
MixtralForTokenClassification,
|
||||||
|
MixtralModel,
|
||||||
|
MixtralPreTrainedModel,
|
||||||
|
MixtralRMSNorm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxConfig(MixtralConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate an
|
||||||
|
MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||||
|
with the defaults will yield a similar configuration to that of the MiniMax.
|
||||||
|
|
||||||
|
[MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf)
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 32000):
|
||||||
|
Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`MiniMaxModel`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 4096):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 14336):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||||
|
Number of hidden layers in the Transformer encoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
||||||
|
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
|
||||||
|
The attention head dimension.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
||||||
|
The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention
|
||||||
|
allows sequence of up to 4096*32 tokens.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
pad_token_id (`int`, *optional*):
|
||||||
|
The id of the padding token.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
The id of the "beginning-of-sequence" token.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
The id of the "end-of-sequence" token.
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether the model's input and output word embeddings should be tied.
|
||||||
|
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
sliding_window (`int`, *optional*):
|
||||||
|
Sliding window attention window size. If not specified, will default to `4096`.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
num_experts_per_tok (`int`, *optional*, defaults to 2):
|
||||||
|
The number of experts to route per-token, can be also interpreted as the `top-k` routing
|
||||||
|
parameter
|
||||||
|
num_local_experts (`int`, *optional*, defaults to 8):
|
||||||
|
Number of experts per Sparse MLP layer.
|
||||||
|
output_router_logits (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not the router logits should be returned by the model. Enabeling this will also
|
||||||
|
allow the model to output the auxiliary loss. See [here]() for more details
|
||||||
|
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
||||||
|
The aux loss factor for the total loss.
|
||||||
|
router_jitter_noise (`float`, *optional*, defaults to 0.0):
|
||||||
|
Amount of noise to add to the router.
|
||||||
|
layer_types (`list`, *optional*):
|
||||||
|
Attention pattern for each layer.
|
||||||
|
block_size (`int`, *optional*, defaults to 256):
|
||||||
|
The length of each attention block, determining how queries, keys, and values
|
||||||
|
are grouped and processed for intra- and inter-block attention.
|
||||||
|
full_attn_alpha_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for residual value in residual connection after normal attention.
|
||||||
|
full_attn_beta_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for hidden state value in residual connection after normal attention.
|
||||||
|
linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for residual value in residual connection after lightning attention.
|
||||||
|
linear_attn_beta_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for hidden state value in residual connection after lightning attention.
|
||||||
|
mlp_alpha_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for residual value in residual connection after MLP.
|
||||||
|
mlp_beta_factor (`float`, *optional*, defaults to 1):
|
||||||
|
Weight for hidden state value in residual connection after MLP.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import MiniMaxModel, MiniMaxConfig
|
||||||
|
|
||||||
|
>>> # Initializing a MiniMax style configuration
|
||||||
|
>>> configuration = MiniMaxConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model from the MiniMax style configuration
|
||||||
|
>>> model = MiniMaxModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_types=None,
|
||||||
|
block_size=256,
|
||||||
|
full_attn_alpha_factor=1,
|
||||||
|
full_attn_beta_factor=1,
|
||||||
|
linear_attn_alpha_factor=1,
|
||||||
|
linear_attn_beta_factor=1,
|
||||||
|
mlp_alpha_factor=1,
|
||||||
|
mlp_beta_factor=1,
|
||||||
|
**super_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**super_kwargs)
|
||||||
|
self.layer_types = layer_types
|
||||||
|
self.block_size = block_size
|
||||||
|
self.full_attn_alpha_factor = full_attn_alpha_factor
|
||||||
|
self.full_attn_beta_factor = full_attn_beta_factor
|
||||||
|
self.linear_attn_alpha_factor = linear_attn_alpha_factor
|
||||||
|
self.linear_attn_beta_factor = linear_attn_beta_factor
|
||||||
|
self.mlp_alpha_factor = mlp_alpha_factor
|
||||||
|
self.mlp_beta_factor = mlp_beta_factor
|
||||||
|
|
||||||
|
if self.layer_types is None:
|
||||||
|
self.layer_types = [
|
||||||
|
"full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers)
|
||||||
|
]
|
||||||
|
layer_type_validation(self.layer_types)
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxRMSNorm(MixtralRMSNorm):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxCache(DynamicCache):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear_cache: List[torch.Tensor] = []
|
||||||
|
|
||||||
|
def set_linear_cache(self, layer_idx, linear_cache):
|
||||||
|
# There may be skipped layers, fill them with empty lists
|
||||||
|
for _ in range(len(self.linear_cache), layer_idx + 1):
|
||||||
|
self.linear_cache.append([])
|
||||||
|
self.linear_cache[layer_idx] = linear_cache
|
||||||
|
|
||||||
|
def get_linear_cache(self, layer_idx: int):
|
||||||
|
if layer_idx < len(self):
|
||||||
|
return self.linear_cache[layer_idx]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return max(super().__len__(), len(self.linear_cache))
|
||||||
|
|
||||||
|
def __getitem__(self, layer_idx: int):
|
||||||
|
if layer_idx < len(self.linear_cache) and self.linear_cache[layer_idx] != []:
|
||||||
|
return (self.linear_cache[layer_idx],)
|
||||||
|
return super().__getitem__(layer_idx)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for layer_idx in range(len(self)):
|
||||||
|
yield self[layer_idx]
|
||||||
|
|
||||||
|
def batch_repeat_interleave(self, repeats: int):
|
||||||
|
for layer_idx in range(len(self)):
|
||||||
|
if self.linear_cache[layer_idx] != []:
|
||||||
|
self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
||||||
|
else:
|
||||||
|
self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
||||||
|
self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
||||||
|
|
||||||
|
def batch_select_indices(self, indices: torch.Tensor):
|
||||||
|
for layer_idx in range(len(self)):
|
||||||
|
if self.linear_cache[layer_idx] != []:
|
||||||
|
self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
|
||||||
|
else:
|
||||||
|
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
|
||||||
|
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
|
||||||
|
|
||||||
|
def crop(self, max_length: int):
|
||||||
|
raise RuntimeError("MiniMaxCache doesnot support `crop` method")
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxLightningAttention(nn.Module):
|
||||||
|
def __init__(self, config: MiniMaxConfig, layer_idx: int):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
||||||
|
self.num_attention_heads = config.num_attention_heads
|
||||||
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
|
self.block_size = config.block_size
|
||||||
|
|
||||||
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads)
|
||||||
|
self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False)
|
||||||
|
self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
||||||
|
self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
|
||||||
|
|
||||||
|
slope_rate = self.get_slope_rate()
|
||||||
|
query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)
|
||||||
|
|
||||||
|
self.register_buffer("slope_rate", slope_rate)
|
||||||
|
self.register_buffer("query_decay", query_decay)
|
||||||
|
self.register_buffer("key_decay", key_decay)
|
||||||
|
self.register_buffer("diagonal_decay", diagonal_decay)
|
||||||
|
|
||||||
|
def get_slope_rate(self):
|
||||||
|
base = 1 / (2 ** (8 / self.num_attention_heads))
|
||||||
|
exponent = torch.arange(self.num_attention_heads) + 1
|
||||||
|
factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5
|
||||||
|
|
||||||
|
rate = base**exponent
|
||||||
|
rate = rate * factor
|
||||||
|
rate = rate[:, None, None]
|
||||||
|
|
||||||
|
return rate
|
||||||
|
|
||||||
|
def decay_factors(self, slope_rate):
|
||||||
|
block_size_range = torch.arange(self.block_size) + 1
|
||||||
|
|
||||||
|
query_decay = torch.exp(-slope_rate * block_size_range[:, None])
|
||||||
|
key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None]))
|
||||||
|
|
||||||
|
diagonal_decay = block_size_range[:, None] - block_size_range[None, :]
|
||||||
|
diagonal_decay = diagonal_decay[None, None, :, :]
|
||||||
|
diagonal_decay = slope_rate * diagonal_decay
|
||||||
|
diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf"))
|
||||||
|
diagonal_decay = torch.exp(diagonal_decay)
|
||||||
|
|
||||||
|
return query_decay, key_decay, diagonal_decay
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor],
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
batch_size, seq_len, hidden_size = hidden_states.shape
|
||||||
|
num_blocks = (seq_len + self.block_size - 1) // self.block_size
|
||||||
|
|
||||||
|
qkv_states = self.act_fn(self.qkv_proj(hidden_states))
|
||||||
|
qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim)
|
||||||
|
|
||||||
|
query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=3)
|
||||||
|
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
# calculated (K.T @ V) and saved as cache
|
||||||
|
attn_weights_inter = None
|
||||||
|
if past_key_value is not None:
|
||||||
|
attn_weights_inter = past_key_value.get_linear_cache(self.layer_idx)
|
||||||
|
|
||||||
|
if attn_weights_inter is None:
|
||||||
|
attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to(
|
||||||
|
value_states
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply attention_mask
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(dtype=torch.bool) # Ensure it's a boolean tensor
|
||||||
|
value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0)
|
||||||
|
|
||||||
|
attn_output = []
|
||||||
|
for i in range(num_blocks):
|
||||||
|
start_idx = i * self.block_size
|
||||||
|
end_idx = min(start_idx + self.block_size, seq_len)
|
||||||
|
current_block_size = end_idx - start_idx
|
||||||
|
|
||||||
|
current_query_states = query_states[:, :, start_idx:end_idx]
|
||||||
|
current_key_states = key_states[:, :, start_idx:end_idx]
|
||||||
|
current_value_states = value_states[:, :, start_idx:end_idx]
|
||||||
|
|
||||||
|
current_query_decay = self.query_decay[:, :current_block_size]
|
||||||
|
current_key_decay = self.key_decay[:, -current_block_size:]
|
||||||
|
current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size]
|
||||||
|
block_decay = torch.exp(-self.slope_rate * current_block_size)
|
||||||
|
|
||||||
|
# intra: ( Q @ K.T ) @ V -> QK * V
|
||||||
|
attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2))
|
||||||
|
attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states)
|
||||||
|
|
||||||
|
# inter: Q @ ( K.T @ V ) -> Q * KV
|
||||||
|
attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter)
|
||||||
|
|
||||||
|
# final attention output
|
||||||
|
current_attn_output = attn_output_inter + attn_output_intra
|
||||||
|
attn_output.append(current_attn_output)
|
||||||
|
|
||||||
|
# cacluate attn_weights_inter for next block or cache
|
||||||
|
next_attn_weights_inter = torch.matmul(
|
||||||
|
(current_key_states * current_key_decay).transpose(-1, -2), current_value_states
|
||||||
|
)
|
||||||
|
attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter
|
||||||
|
|
||||||
|
else:
|
||||||
|
ratio = torch.exp(-self.slope_rate)
|
||||||
|
attn_output = []
|
||||||
|
for i in range(seq_len):
|
||||||
|
current_query_states = query_states[:, :, i : i + 1]
|
||||||
|
current_key_states = key_states[:, :, i : i + 1]
|
||||||
|
current_value_states = value_states[:, :, i : i + 1]
|
||||||
|
|
||||||
|
current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states)
|
||||||
|
attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter
|
||||||
|
current_attn_output = torch.matmul(current_query_states, attn_weights_inter)
|
||||||
|
|
||||||
|
attn_output.append(current_attn_output)
|
||||||
|
|
||||||
|
# concatenate attention outputs over all blocks
|
||||||
|
attn_output = torch.cat(attn_output, dim=-2)
|
||||||
|
|
||||||
|
# final output projection
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim)
|
||||||
|
attn_output = self.norm(attn_output)
|
||||||
|
attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
# update cache
|
||||||
|
if past_key_value is not None:
|
||||||
|
past_key_value.set_linear_cache(self.layer_idx, attn_weights_inter)
|
||||||
|
|
||||||
|
return attn_output, attn_weights_inter
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxAttention(MixtralAttention):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):
|
||||||
|
def __init__(self, config: MiniMaxConfig, layer_idx: int):
|
||||||
|
super().__init__(config, layer_idx)
|
||||||
|
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.layer_type = config.layer_types[layer_idx]
|
||||||
|
self.mlp_alpha_factor = config.mlp_alpha_factor
|
||||||
|
self.mlp_beta_factor = config.mlp_beta_factor
|
||||||
|
|
||||||
|
if self.layer_type == "linear_attention":
|
||||||
|
self.self_attn = MiniMaxLightningAttention(config, layer_idx)
|
||||||
|
self.attn_alpha_factor = config.linear_attn_alpha_factor
|
||||||
|
self.attn_beta_factor = config.linear_attn_beta_factor
|
||||||
|
else:
|
||||||
|
self.self_attn = MiniMaxAttention(config, layer_idx)
|
||||||
|
self.attn_alpha_factor = config.full_attn_alpha_factor
|
||||||
|
self.attn_beta_factor = config.full_attn_beta_factor
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
output_router_logits: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`):
|
||||||
|
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||||
|
with `head_dim` being the embedding dimension of each attention head.
|
||||||
|
attention_mask (`torch.Tensor`, *optional*): attention mask of size
|
||||||
|
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_router_logits (`bool`, *optional*):
|
||||||
|
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
||||||
|
should not be returned during inference.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
kwargs (`dict`, *optional*):
|
||||||
|
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||||
|
into the model
|
||||||
|
"""
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
||||||
|
hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if output_router_logits:
|
||||||
|
outputs += (router_logits,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
|
||||||
|
_supports_cache_class = True # Note: only supports MiniMaxCache
|
||||||
|
_supports_static_cache = False
|
||||||
|
_supports_quantized_cache = False
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxModel(MixtralModel):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_router_logits: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||||
|
) -> MoeModelOutputWithPast:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_router_logits = (
|
||||||
|
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if use_cache and past_key_values is None:
|
||||||
|
past_key_values = MiniMaxCache()
|
||||||
|
elif use_cache and not isinstance(past_key_values, MiniMaxCache):
|
||||||
|
raise ValueError(
|
||||||
|
f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
||||||
|
causal_mask = mask_function(
|
||||||
|
config=self.config,
|
||||||
|
input_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cache_position=cache_position,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# create position embeddings to be shared across the decoder layers
|
||||||
|
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
all_router_logits = () if output_router_logits else None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
if decoder_layer.layer_type == "full_attention":
|
||||||
|
input_attention_mask = causal_mask
|
||||||
|
else:
|
||||||
|
# lightning attention uses original attention_mask, and uses it only for the first step
|
||||||
|
input_attention_mask = attention_mask
|
||||||
|
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
attention_mask=input_attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_router_logits=output_router_logits,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**flash_attn_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
if output_router_logits:
|
||||||
|
all_router_logits += (layer_outputs[-1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
return MoeModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
router_logits=all_router_logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxForCausalLM(MixtralForCausalLM):
|
||||||
|
def forward(self, **super_kwargs):
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, MiniMaxForCausalLM
|
||||||
|
|
||||||
|
>>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
return super().forward(**super_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxForSequenceClassification(MixtralForSequenceClassification):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxForTokenClassification(MixtralForTokenClassification):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxForQuestionAnswering(MixtralForQuestionAnswering):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MiniMaxConfig",
|
||||||
|
"MiniMaxPreTrainedModel",
|
||||||
|
"MiniMaxModel",
|
||||||
|
"MiniMaxForCausalLM",
|
||||||
|
"MiniMaxForSequenceClassification",
|
||||||
|
"MiniMaxForTokenClassification",
|
||||||
|
"MiniMaxForQuestionAnswering",
|
||||||
|
]
|
0
tests/models/minimax/__init__.py
Normal file
0
tests/models/minimax/__init__.py
Normal file
279
tests/models/minimax/test_modeling_minimax.py
Normal file
279
tests/models/minimax/test_modeling_minimax.py
Normal file
@ -0,0 +1,279 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Testing suite for the PyTorch MiniMax model."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from transformers import MiniMaxConfig, is_torch_available
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
require_flash_attn,
|
||||||
|
require_torch,
|
||||||
|
require_torch_accelerator,
|
||||||
|
require_torch_gpu,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
MiniMaxForCausalLM,
|
||||||
|
MiniMaxForQuestionAnswering,
|
||||||
|
MiniMaxForSequenceClassification,
|
||||||
|
MiniMaxForTokenClassification,
|
||||||
|
MiniMaxModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxModelTester(CausalLMModelTester):
|
||||||
|
config_class = MiniMaxConfig
|
||||||
|
if is_torch_available():
|
||||||
|
base_model_class = MiniMaxModel
|
||||||
|
causal_lm_class = MiniMaxForCausalLM
|
||||||
|
sequence_class = MiniMaxForSequenceClassification
|
||||||
|
token_class = MiniMaxForTokenClassification
|
||||||
|
question_answering_class = MiniMaxForQuestionAnswering
|
||||||
|
|
||||||
|
def __init__(self, parent, layer_types=None, block_size=3):
|
||||||
|
super().__init__(parent)
|
||||||
|
self.layer_types = layer_types
|
||||||
|
self.block_size = block_size
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase):
|
||||||
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
MiniMaxModel,
|
||||||
|
MiniMaxForCausalLM,
|
||||||
|
MiniMaxForSequenceClassification,
|
||||||
|
MiniMaxForTokenClassification,
|
||||||
|
MiniMaxForQuestionAnswering,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{
|
||||||
|
"feature-extraction": MiniMaxModel,
|
||||||
|
"text-classification": MiniMaxForSequenceClassification,
|
||||||
|
"token-classification": MiniMaxForTokenClassification,
|
||||||
|
"text-generation": MiniMaxForCausalLM,
|
||||||
|
"question-answering": MiniMaxForQuestionAnswering,
|
||||||
|
}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
|
test_headmasking = False
|
||||||
|
test_pruning = False
|
||||||
|
model_tester_class = MiniMaxModelTester
|
||||||
|
|
||||||
|
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||||
|
def is_pipeline_test_to_skip(
|
||||||
|
self,
|
||||||
|
pipeline_test_case_name,
|
||||||
|
config_class,
|
||||||
|
model_architecture,
|
||||||
|
tokenizer_name,
|
||||||
|
image_processor_name,
|
||||||
|
feature_extractor_name,
|
||||||
|
processor_name,
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@pytest.mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
|
self.skipTest(reason="MiniMax flash attention does not support right padding")
|
||||||
|
|
||||||
|
def test_load_balancing_loss(self):
|
||||||
|
r"""
|
||||||
|
Let's make sure we can actually compute the loss and do a backward on it.
|
||||||
|
"""
|
||||||
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.num_labels = 3
|
||||||
|
config.num_local_experts = 8
|
||||||
|
config.output_router_logits = True
|
||||||
|
input_ids = input_dict["input_ids"]
|
||||||
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
model = MiniMaxForCausalLM(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=attention_mask)
|
||||||
|
self.assertEqual(result.router_logits[0].shape, (91, config.num_local_experts))
|
||||||
|
torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
|
# First, we make sure that adding padding tokens doesn't change the loss
|
||||||
|
# loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding)
|
||||||
|
pad_length = 1000
|
||||||
|
# Add padding tokens (assume that pad_token_id=1) to input_ids
|
||||||
|
padding_block = torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(torch_device)
|
||||||
|
padded_input_ids = torch.cat((padding_block, input_ids), dim=1) # this is to simulate padding to the left
|
||||||
|
padded_attention_mask = padded_input_ids.ne(1).to(torch_device)
|
||||||
|
|
||||||
|
padded_result = model(padded_input_ids, attention_mask=padded_attention_mask)
|
||||||
|
torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
|
# We make sure that the loss of including padding tokens != the loss without padding tokens
|
||||||
|
# if attention_mask=None --> we don't exclude padding tokens
|
||||||
|
include_padding_result = model(padded_input_ids, attention_mask=None)
|
||||||
|
|
||||||
|
# This is to mimic torch.testing.assert_not_close
|
||||||
|
self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item())
|
||||||
|
|
||||||
|
def _check_attentions_for_generate(
|
||||||
|
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
|
):
|
||||||
|
self.assertIsInstance(attentions, tuple)
|
||||||
|
self.assertListEqual(
|
||||||
|
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||||
|
)
|
||||||
|
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||||
|
use_cache = decoder_past_key_values is not None
|
||||||
|
|
||||||
|
for generated_length, iter_attentions in enumerate(attentions):
|
||||||
|
# regardless of using cache, the first forward pass will have the full prompt as input
|
||||||
|
if use_cache and generated_length > 0:
|
||||||
|
model_input_length = 1
|
||||||
|
else:
|
||||||
|
model_input_length = prompt_length + generated_length
|
||||||
|
|
||||||
|
expected_shape = (
|
||||||
|
batch_size,
|
||||||
|
config.num_attention_heads,
|
||||||
|
model_input_length,
|
||||||
|
prompt_length + generated_length,
|
||||||
|
)
|
||||||
|
for layer_idx, layer_attention in enumerate(iter_attentions):
|
||||||
|
if config.layer_types[layer_idx] == "full_attention":
|
||||||
|
self.assertEqual(layer_attention.shape, expected_shape)
|
||||||
|
|
||||||
|
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
|
||||||
|
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
|
||||||
|
|
||||||
|
# (batch, head, seq_length, head_features)
|
||||||
|
key_value_cache_expected_shape = (
|
||||||
|
batch_size,
|
||||||
|
config.num_key_value_heads,
|
||||||
|
cache_length,
|
||||||
|
config.hidden_size // config.num_attention_heads,
|
||||||
|
)
|
||||||
|
# (batch, head, head_features, head_features)
|
||||||
|
linear_cache_expected_shape = (
|
||||||
|
batch_size,
|
||||||
|
config.num_attention_heads,
|
||||||
|
config.hidden_size // config.num_attention_heads,
|
||||||
|
config.hidden_size // config.num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
for layer_idx in range(config.num_hidden_layers):
|
||||||
|
if config.layer_types[layer_idx] == "full_attention":
|
||||||
|
self.assertEqual(decoder_past_key_values[layer_idx][0].shape, key_value_cache_expected_shape)
|
||||||
|
self.assertEqual(decoder_past_key_values[layer_idx][1].shape, key_value_cache_expected_shape)
|
||||||
|
else:
|
||||||
|
self.assertEqual(decoder_past_key_values[layer_idx][0].shape, linear_cache_expected_shape)
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
|
def test_past_key_values_format(self, custom_all_cache_shapes=None):
|
||||||
|
"""
|
||||||
|
Test that the KV cache is formatted correctly.
|
||||||
|
"""
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
model = model_class(config).to(torch_device)
|
||||||
|
model = model.eval()
|
||||||
|
if "use_cache" not in inputs:
|
||||||
|
inputs["use_cache"] = True
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
past_kv = outputs["past_key_values"]
|
||||||
|
|
||||||
|
batch_size, seq_length = inputs["input_ids"].shape
|
||||||
|
self._check_past_key_values_for_generate(batch_size, past_kv, seq_length, config)
|
||||||
|
|
||||||
|
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||||
|
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||||
|
def test_contrastive_generate_low_memory(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||||
|
def test_assisted_decoding_sample(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||||
|
def test_assisted_decoding_matches_greedy_search_0_random(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||||
|
def test_assisted_decoding_matches_greedy_search_1_same(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||||
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_torch_accelerator
|
||||||
|
@slow
|
||||||
|
class MiniMaxIntegrationTest(unittest.TestCase):
|
||||||
|
def test_small_model_logits(self):
|
||||||
|
model_id = "geetu040/MiniMax-tiny"
|
||||||
|
dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device)
|
||||||
|
|
||||||
|
model = MiniMaxForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[[1.0312, -0.5156, -0.3262], [-0.1152, 0.4336, 0.2412], [1.2188, -0.5898, -0.0381]]
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(dummy_input).logits
|
||||||
|
|
||||||
|
logits = logits.float()
|
||||||
|
|
||||||
|
torch.testing.assert_close(logits[0, :3, :3], expected_slice, atol=1e-3, rtol=1e-3)
|
||||||
|
torch.testing.assert_close(logits[1, :3, :3], expected_slice, atol=1e-3, rtol=1e-3)
|
||||||
|
|
||||||
|
def test_small_model_generation(self):
|
||||||
|
model_id = "geetu040/MiniMax-tiny"
|
||||||
|
dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device)
|
||||||
|
|
||||||
|
model = MiniMaxForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
expected_slice = (
|
||||||
|
torch.tensor([[0, 1, 0, 933, 307, 3102, 2457, 1208], [0, 1, 0, 933, 307, 3102, 2457, 1208]])
|
||||||
|
.to(torch.int64)
|
||||||
|
.to(torch_device)
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = model.generate(dummy_input, max_new_tokens=5, do_sample=False)
|
||||||
|
|
||||||
|
torch.testing.assert_close(outputs, expected_slice, atol=1e-3, rtol=1e-3)
|
@ -3946,7 +3946,7 @@ class ModelTesterMixin:
|
|||||||
if not self.has_attentions:
|
if not self.has_attentions:
|
||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen_moe", "starcoder2"]
|
WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "minimax", "qwen2", "qwen_moe", "starcoder2"]
|
||||||
|
|
||||||
if len(self.all_generative_model_classes) == 0:
|
if len(self.all_generative_model_classes) == 0:
|
||||||
self.skipTest(f"No generative model classes for {self.__class__.__name__}")
|
self.skipTest(f"No generative model classes for {self.__class__.__name__}")
|
||||||
|
Loading…
Reference in New Issue
Block a user