
Commit 5eea3b9

geetu040, Shakib-IO, and Cyrilvallez authored and committed
Add support for MiniMax's MiniMax-Text-01 (huggingface#35831)
* end-to-end architecture
* lightning-attn: refactor, clean, optimize
* put minimax_text_01 in other files
* use latest __init__ standards and auto-generate modular
* support attention_mask for lightning-attn
* Revert "use latest __init__ standards and auto-generate modular"; this reverts commit d8d3c40.
* fix modular conversion
* pass both attention masks instead of tuple
* formatting
* Updated Dynamic Cache
* created MiniMaxText01Cache
* fix hardcoded slope_rate
* update attn_type_list in config
* fix lightning when use_cache=False
* copy tests from mixtral
* (checkpoint) all tests pass for normal attention
* fix all unittests
* fix import sorting
* fix consistency and formatting tests
* fix config
* update tests, since changes in main
* fix seq_len error
* create dummy docs
* fix checkpoint
* add checkpoint in config docstring
* run modular_conversion
* update docs
* fix checkpoint path and update tests
* fix ruff
* remove repeated expected_slice
* update docs
* rename "minimax-text-01" to "minimax"
* inherit config from mixtral
* remove from docs in other languages
* undo files that should be untouched
* move minimax to end in conversation docs
* use MiniMaxForCausalLM as it is
* ruff fixes
* run modular
* fix docstring example in causallm
* refactor attention loop and decay factors
* refactor config in modular
* run modular
* refactor cache
* rename static_cache to linear_cache
* make positional embeddings necessary
* remove unnecessary layernorms declarations
* fix import in tests
* refactor attention in next tokens
* remove outdated code
* formatting and modular
* update tests
* rename layernorm alpha/beta factors
* register decay factors as buffers
* remove unused declarations of decay factors
* update config for alpha/beta factors
* run modular
* remove head_dim in tests
* remove minimax from fx.py
* remove stuff that is not really needed
* update __init__
* update qkv torch.split (Co-authored-by: Cyril Vallez <[email protected]>)
* fix qkv torch.split
* quality fixes
* remove mistakenly added dummy
* purge unused ModelTester code
* fix-copies
* run fix-copies
* fix head_dim
* write cache formatting tests
* remove postnorm
* avoid contiguous in attention current states
* update expected_slice
* add generation test for integration
* fix dtype in generation test
* update authors
* update with changes in main
* update graident checkpointing and minor fixes
* fix mutable attn_type_list
* rename: attn_type -> layer_type
* update for layer_types
* update integration tests
* update checkpoint
* clean overview in docs

---------

Co-authored-by: Shakib-IO <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
1 parent 8db1e67 commit 5eea3b9


15 files changed: +2650, -1 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -555,6 +555,8 @@
       title: MegatronBERT
     - local: model_doc/megatron_gpt2
       title: MegatronGPT2
+    - local: model_doc/minimax
+      title: MiniMax
     - local: model_doc/mistral
       title: Mistral
     - local: model_doc/mixtral

docs/source/en/model_doc/minimax.md

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
<!--Copyright 2025 MiniMaxAI and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# MiniMax

## Overview

The MiniMax model was proposed in [MiniMax-01: Scaling Foundation Models with Lightning Attention](https://arxiv.org/abs/2501.08313) by MiniMax, Aonian Li, Bangwei Gong, Bo Yang, Boji Shan, Chang Liu, Cheng Zhu, Chunhao Zhang, Congchao Guo, Da Chen, Dong Li, Enwei Jiao, Gengxin Li, Guojun Zhang, Haohai Sun, Houze Dong, Jiadai Zhu, Jiaqi Zhuang, Jiayuan Song, Jin Zhu, Jingtao Han, Jingyang Li, Junbin Xie, Junhao Xu, Junjie Yan, Kaishun Zhang, Kecheng Xiao, Kexi Kang, Le Han, Leyang Wang, Lianfei Yu, Liheng Feng, Lin Zheng, Linbo Chai, Long Xing, Meizhi Ju, Mingyuan Chi, Mozhi Zhang, Peikai Huang, Pengcheng Niu, Pengfei Li, Pengyu Zhao, Qi Yang, Qidi Xu, Qiexiang Wang, Qin Wang, Qiuhui Li, Ruitao Leng, Shengmin Shi, Shuqi Yu, Sichen Li, Songquan Zhu, Tao Huang, Tianrun Liang, Weigao Sun, Weixuan Sun, Weiyu Cheng, Wenkai Li, Xiangjun Song, Xiao Su, Xiaodong Han, Xinjie Zhang, Xinzhu Hou, Xu Min, Xun Zou, Xuyang Shen, Yan Gong, Yingjie Zhu, Yipeng Zhou, Yiran Zhong, Yongyi Hu, Yuanxiang Fan, Yue Yu, Yufeng Yang, Yuhao Li, Yunan Huang, Yunji Li, Yunpeng Huang, Yunzhi Xu, Yuxin Mao, Zehan Li, Zekang Li, Zewei Tao, Zewen Ying, Zhaoyang Cong, Zhen Qin, Zhenhua Fan, Zhihang Yu, Zhuo Jiang, Zijia Wu.

The abstract from the paper is the following:

*We introduce MiniMax-01 series, including MiniMax-Text-01 and MiniMax-VL-01, which are comparable to top-tier models while offering superior capabilities in processing longer contexts. The core lies in lightning attention and its efficient scaling. To maximize computational capacity, we integrate it with Mixture of Experts (MoE), creating a model with 32 experts and 456 billion total parameters, of which 45.9 billion are activated for each token. We develop an optimized parallel strategy and highly efficient computation-communication overlap techniques for MoE and lightning attention. This approach enables us to conduct efficient training and inference on models with hundreds of billions of parameters across contexts spanning millions of tokens. The context window of MiniMax-Text-01 can reach up to 1 million tokens during training and extrapolate to 4 million tokens during inference at an affordable cost. Our vision-language model, MiniMax-VL-01 is built through continued training with 512 billion vision-language tokens. Experiments on both standard and in-house benchmarks show that our models match the performance of state-of-the-art models like GPT-4o and Claude-3.5-Sonnet while offering 20-32 times longer context window.*
### Architectural details

MiniMax is a powerful language model with 456 billion total parameters, of which 45.9 billion are activated per token. To better unlock the long-context capabilities of the model, MiniMax adopts a hybrid architecture that combines Lightning Attention, Softmax Attention and Mixture-of-Experts (MoE). Leveraging advanced parallel strategies and innovative compute-communication overlap methods such as Linear Attention Sequence Parallelism Plus (LASP+), varlen ring attention, and Expert Tensor Parallel (ETP), MiniMax extends its training context length to 1 million tokens and can handle a context of up to 4 million tokens during inference. On various academic benchmarks, MiniMax also demonstrates the performance of a top-tier model.

The architecture of MiniMax is briefly described as follows:

- Total Parameters: 456B
- Activated Parameters per Token: 45.9B
- Number of Layers: 80
- Hybrid Attention: a softmax attention layer is positioned after every 7 lightning attention layers
- Number of attention heads: 64
- Attention head dimension: 128
- Mixture of Experts:
  - Number of experts: 32
  - Expert hidden dimension: 9216
  - Top-2 routing strategy
- Positional Encoding: Rotary Position Embedding (RoPE) applied to half of the attention head dimension with a base frequency of 10,000,000
- Hidden Size: 6144
- Vocab Size: 200,064

For more details refer to the [release blog post](https://www.minimaxi.com/en/news/minimax-01-series-2).
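
The numbers above map directly onto the model configuration. Below is a minimal sketch of building a small, randomly initialized MiniMax-style model for experimentation; it assumes the Mixtral-style parameter names the config inherits, so treat the argument names as assumptions rather than the verified signature.

```python
from transformers import MiniMaxConfig, MiniMaxForCausalLM

# Scaled-down toy values; the released checkpoint uses the figures listed above
# (80 layers, hidden size 6144, 64 heads, 32 experts, vocab size 200,064, ...).
config = MiniMaxConfig(            # argument names assumed from the Mixtral-style config
    hidden_size=512,
    intermediate_size=1024,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=8,
    num_local_experts=8,           # 32 in the released model
    num_experts_per_tok=2,         # top-2 routing
    vocab_size=200_064,
    rope_theta=10_000_000,
)
model = MiniMaxForCausalLM(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```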

### License

`MiniMax` is released under the MINIMAX MODEL LICENSE AGREEMENT.

## Usage tips

The pre-trained model can be used as follows:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")

>>> messages = [
...     {"role": "user", "content": "What is your favourite condiment?"},
...     {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
...     {"role": "user", "content": "Do you have mayonnaise recipes?"}
... ]

>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"Mayonnaise can be made as follows: (...)"
```

As can be seen, the instruction-tuned model requires a [chat template](../chat_templating) to be applied to make sure the inputs are prepared in the right format.

## Speeding up MiniMax by using Flash Attention

The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.

First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). Also make sure to load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")

>>> prompt = "My favourite condiment is"

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
```

### Sliding window attention

The current implementation supports the sliding window attention mechanism and memory-efficient cache management.
To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`).

The Flash Attention 2 model also uses a more memory-efficient cache slicing mechanism: as recommended by the official implementation of the Mistral model, which uses a rolling cache, we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"`, and use the absolute position of the current token to compute the positional embedding.
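
Since batched generation is only supported with `padding_side="left"`, a minimal sketch of preparing a left-padded batch is shown below; the pad-token assignment is an assumption for checkpoints whose tokenizer does not define one.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS as the padding token

model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", device_map="auto")

prompts = ["My favourite condiment is", "The best way to make mayonnaise is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=50, do_sample=True)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
```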

## Shrinking down MiniMax using quantization

As the MiniMax model has 456 billion parameters, that would require about 912 GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter), about 228 GB of RAM is required.
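
For reference, a quick back-of-the-envelope check of those numbers (weights only, ignoring activations and the KV cache):

```python
params = 456e9  # total parameters

print(f"float16: {params * 2 / 1e9:.0f} GB")    # 2 bytes per parameter   -> ~912 GB
print(f"4-bit  : {params * 0.5 / 1e9:.0f} GB")  # 0.5 bytes per parameter -> ~228 GB
```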

Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization.md) for alternative quantization methods):

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

>>> # specify how to quantize the model
>>> quantization_config = BitsAndBytesConfig(
...     load_in_4bit=True,
...     bnb_4bit_quant_type="nf4",
...     bnb_4bit_compute_dtype=torch.float16,
... )

>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf", quantization_config=quantization_config, device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")

>>> messages = [
...     {"role": "user", "content": "What is your favourite condiment?"},
...     {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
...     {"role": "user", "content": "Do you have mayonnaise recipes?"}
... ]

>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
```

This model was contributed by [geetu040](https://github.com/geetu040).
The original code can be found [here](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf/blob/main/modeling_minimax.py).

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with MiniMax. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

<PipelineTag pipeline="text-generation"/>

- The [Alignment Handbook](https://github.com/huggingface/alignment-handbook) by Hugging Face includes scripts and recipes to perform supervised fine-tuning (SFT) and direct preference optimization with Mistral-7B. This includes scripts for full fine-tuning, QLoRA on a single GPU as well as multi-GPU fine-tuning.
- [Causal language modeling task guide](../tasks/language_modeling)

## MiniMaxConfig

[[autodoc]] MiniMaxConfig

## MiniMaxModel

[[autodoc]] MiniMaxModel
    - forward

## MiniMaxForCausalLM

[[autodoc]] MiniMaxForCausalLM
    - forward

## MiniMaxForSequenceClassification

[[autodoc]] MiniMaxForSequenceClassification
    - forward

## MiniMaxForTokenClassification

[[autodoc]] MiniMaxForTokenClassification
    - forward

## MiniMaxForQuestionAnswering
[[autodoc]] MiniMaxForQuestionAnswering
    - forward

src/transformers/configuration_utils.py

Lines changed: 1 addition & 0 deletions
@@ -1218,6 +1218,7 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
     "full_attention",
     "sliding_attention",
     "chunked_attention",
+    "linear_attention",  # used in minimax
 )

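The new `"linear_attention"` entry corresponds to the lightning-attention layers of the hybrid stack. As an illustration only, the per-layer attention types could be inspected as sketched below; the `layer_types` attribute name is inferred from the allowed values above and the commit message, not verified against the final API.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
# Expected pattern: seven "linear_attention" layers followed by one "full_attention" layer.
print(config.layer_types[:16])
```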
src/transformers/generation/utils.py

Lines changed: 1 addition & 0 deletions
@@ -1976,6 +1976,7 @@ def _supports_default_dynamic_cache(self) -> bool:
             and "jamba" not in self.__class__.__name__.lower()
             and "zamba" not in self.__class__.__name__.lower()
             and "bamba" not in self.__class__.__name__.lower()
+            and "minimax" not in self.__class__.__name__.lower()
         )

     def _prepare_cache_for_generation(

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -185,6 +185,7 @@
     from .megatron_gpt2 import *
     from .mgp_str import *
     from .mimi import *
+    from .minimax import *
     from .mistral import *
     from .mistral3 import *
     from .mixtral import *

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -211,6 +211,7 @@
         ("megatron-bert", "MegatronBertConfig"),
         ("mgp-str", "MgpstrConfig"),
         ("mimi", "MimiConfig"),
+        ("minimax", "MiniMaxConfig"),
         ("mistral", "MistralConfig"),
         ("mistral3", "Mistral3Config"),
         ("mixtral", "MixtralConfig"),
@@ -586,6 +587,7 @@
         ("megatron_gpt2", "Megatron-GPT2"),
         ("mgp-str", "MGP-STR"),
         ("mimi", "Mimi"),
+        ("minimax", "MiniMax"),
         ("mistral", "Mistral"),
         ("mistral3", "Mistral3"),
         ("mixtral", "Mixtral"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 5 additions & 0 deletions
@@ -201,6 +201,7 @@
         ("megatron-bert", "MegatronBertModel"),
         ("mgp-str", "MgpstrForSceneTextRecognition"),
         ("mimi", "MimiModel"),
+        ("minimax", "MiniMaxModel"),
         ("mistral", "MistralModel"),
         ("mistral3", "Mistral3Model"),
         ("mixtral", "MixtralModel"),
@@ -594,6 +595,7 @@
         ("mbart", "MBartForCausalLM"),
         ("mega", "MegaForCausalLM"),
         ("megatron-bert", "MegatronBertForCausalLM"),
+        ("minimax", "MiniMaxForCausalLM"),
         ("mistral", "MistralForCausalLM"),
         ("mixtral", "MixtralForCausalLM"),
         ("mllama", "MllamaForCausalLM"),
@@ -1106,6 +1108,7 @@
         ("mbart", "MBartForSequenceClassification"),
         ("mega", "MegaForSequenceClassification"),
         ("megatron-bert", "MegatronBertForSequenceClassification"),
+        ("minimax", "MiniMaxForSequenceClassification"),
         ("mistral", "MistralForSequenceClassification"),
         ("mixtral", "MixtralForSequenceClassification"),
         ("mobilebert", "MobileBertForSequenceClassification"),
@@ -1197,6 +1200,7 @@
         ("mbart", "MBartForQuestionAnswering"),
         ("mega", "MegaForQuestionAnswering"),
         ("megatron-bert", "MegatronBertForQuestionAnswering"),
+        ("minimax", "MiniMaxForQuestionAnswering"),
         ("mistral", "MistralForQuestionAnswering"),
         ("mixtral", "MixtralForQuestionAnswering"),
         ("mobilebert", "MobileBertForQuestionAnswering"),
@@ -1303,6 +1307,7 @@
         ("markuplm", "MarkupLMForTokenClassification"),
         ("mega", "MegaForTokenClassification"),
         ("megatron-bert", "MegatronBertForTokenClassification"),
+        ("minimax", "MiniMaxForTokenClassification"),
         ("mistral", "MistralForTokenClassification"),
         ("mixtral", "MixtralForTokenClassification"),
         ("mobilebert", "MobileBertForTokenClassification"),

src/transformers/models/auto/tokenization_auto.py

Lines changed: 7 additions & 0 deletions
@@ -342,6 +342,13 @@
         ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
         ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         ("mgp-str", ("MgpstrTokenizer", None)),
+        (
+            "minimax",
+            (
+                "GPT2Tokenizer" if is_sentencepiece_available() else None,
+                "GPT2TokenizerFast" if is_tokenizers_available() else None,
+            ),
+        ),
         (
             "mistral",
             (
src/transformers/models/minimax/__init__.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_minimax import *
    from .modeling_minimax import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
