Add support for MiniMax's MiniMax-Text-01 #35831


Merged

merged 103 commits on Jun 4, 2025

Changes shown from 97 of 103 commits

Commits
3a046a4
end-to-end architecture
geetu040 Jan 22, 2025
519eda3
lightning-attn: refactor, clean, optimize
geetu040 Jan 23, 2025
c54f804
put minimax_text_01 in other files
geetu040 Jan 24, 2025
d8d3c40
use latest __init__ standards and auto-generate modular
geetu040 Jan 27, 2025
8d654d8
support attention_mask for lightning-attn
geetu040 Jan 27, 2025
70b787c
Merge branch 'main' into minimax-text-01
geetu040 Jan 27, 2025
5b40a5c
Revert "use latest __init__ standards and auto-generate modular"
geetu040 Jan 27, 2025
a93ee3f
fix modular conversion
geetu040 Jan 27, 2025
92f7963
pass both attention masks instead of tuple
geetu040 Jan 27, 2025
ecdc0eb
formatting
geetu040 Jan 27, 2025
0027c0e
Updated Dynamic Cache
Shakib-IO Jan 30, 2025
e117d26
created MiniMaxText01Cache
geetu040 Jan 31, 2025
7d7ae06
fix hardcoded slope_rate
geetu040 Feb 1, 2025
d209bd3
update attn_type_list in config
geetu040 Feb 1, 2025
d6eb561
fix lightning when use_cache=False
geetu040 Feb 1, 2025
99db3a0
Merge branch 'main' into minimax-text-01
geetu040 Feb 7, 2025
e459521
copy tests from mixtral
geetu040 Feb 7, 2025
866ba89
(checkpoint) all tests pass for normal attention
geetu040 Feb 7, 2025
1a6f086
fix all unittests
geetu040 Feb 10, 2025
e27af34
Merge branch 'main' into minimax-text-01
geetu040 Feb 10, 2025
5ba152b
fix import sorting
geetu040 Feb 10, 2025
8fd17ca
fix consistency and formatting tests
geetu040 Feb 10, 2025
d907633
fix config
geetu040 Feb 12, 2025
09b07ea
Merge branch 'main' into minimax-text-01
geetu040 Feb 12, 2025
c08a619
update tests, since changes in main
geetu040 Feb 12, 2025
8117b2d
fix seq_len error
geetu040 Feb 13, 2025
2f9dc21
create dummy docs
geetu040 Feb 13, 2025
299e707
fix checkpoint
geetu040 Feb 13, 2025
33e4157
add checkpoint in config docstring
geetu040 Feb 13, 2025
852ce2e
Merge branch 'main' into minimax-text-01
geetu040 Feb 13, 2025
85b3dcb
run modular_conversion
geetu040 Feb 13, 2025
cdaa09b
update docs
geetu040 Feb 13, 2025
4eb7a08
Merge branch 'main' into minimax-text-01
geetu040 Feb 14, 2025
1f63bbb
fix checkpoint path and update tests
geetu040 Feb 15, 2025
d95ea2d
fix ruff
geetu040 Feb 16, 2025
2150a5f
Merge branch 'main' into minimax-text-01
geetu040 Feb 16, 2025
720fd4f
remove repeated expected_slice
geetu040 Feb 16, 2025
c527d34
update docs
geetu040 Feb 16, 2025
3603b4d
Merge branch 'main' into minimax-text-01
geetu040 Feb 18, 2025
adfc7c5
Merge branch 'main' into minimax-text-01
geetu040 Feb 24, 2025
f1f669b
rename "minimax-text-01" to "minimax"
geetu040 Feb 24, 2025
8309220
inherit config from mixtral
geetu040 Feb 24, 2025
3dd21d0
remove from docs in other languages
geetu040 Feb 24, 2025
fac8ba6
undo files that should be untouched
geetu040 Feb 24, 2025
3dfb5c8
move minimax to end in conversation docs
geetu040 Feb 24, 2025
d6304e4
use MiniMaxForCausalLM as it is
geetu040 Feb 24, 2025
7721cdd
ruff fixes
geetu040 Feb 24, 2025
f7004f5
Merge branch "origin/main"; resolve conflicts
geetu040 Mar 6, 2025
7def81b
run modular
geetu040 Mar 6, 2025
b543d0b
fix docstring example in causallm
geetu040 Mar 6, 2025
64f6e22
refactor attention loop and decay factors
geetu040 Mar 7, 2025
327c033
refactor config in modular
geetu040 Mar 7, 2025
caa1529
run modular
geetu040 Mar 7, 2025
d3f206d
refactor cache
geetu040 Mar 8, 2025
9398d79
Merge remote-tracking branch 'origin/main' into minimax-text-01
geetu040 Mar 8, 2025
cd51672
rename static_cache to linear_cache
geetu040 Mar 8, 2025
693eab2
make positional embeddings necessary
geetu040 Mar 8, 2025
6f953ee
remove unnecessary layernorms declarations
geetu040 Mar 8, 2025
58b95a0
fix import in tests
geetu040 Mar 8, 2025
ffb7b22
refactor attention in next tokens
geetu040 Mar 8, 2025
8de6ae1
remove outdated code
geetu040 Mar 8, 2025
3544b5e
formatting and modular
geetu040 Mar 8, 2025
968f357
update tests
geetu040 Mar 8, 2025
d285c8f
rename layernorm alpha/beta factors
geetu040 Mar 8, 2025
bc47dfb
register decay factors as buffers
geetu040 Mar 8, 2025
a6930a8
remove unused declarations of decay factors
geetu040 Mar 8, 2025
7549354
update config for alpha/beta factors
geetu040 Mar 17, 2025
532c88b
Merge branch 'main' into minimax-text-01
geetu040 Mar 17, 2025
c304884
run modular
geetu040 Mar 17, 2025
8837c90
remove head_dim in tests
geetu040 Mar 17, 2025
15f01aa
Merge branch main into minimax-text-01 (resolve conflicts)
geetu040 Apr 6, 2025
b02abe1
remove minimax from fx.py
geetu040 Apr 6, 2025
ef4c5ec
remove stuff that is not really needed
geetu040 Apr 6, 2025
f625d76
update __init__
geetu040 Apr 6, 2025
2428fa2
update qkv torch.split
geetu040 Apr 6, 2025
141cc39
fix qkv torch.split
geetu040 Apr 6, 2025
04f4511
quality fixes
geetu040 Apr 6, 2025
13b35c9
remove mistakenly added dummy
geetu040 Apr 6, 2025
91b71b7
purge unused ModelTester code
geetu040 Apr 6, 2025
8b00381
fix-copies
geetu040 Apr 6, 2025
34fc564
Merge branch "main"
geetu040 Apr 23, 2025
c2a5c4a
run fix-copies
geetu040 Apr 23, 2025
cd15fcd
fix head_dim
geetu040 Apr 23, 2025
fb68e6a
write cache formatting tests
geetu040 Apr 24, 2025
7c1d499
remove postnorm
geetu040 Apr 24, 2025
9d76269
avoid contiguous in attention current states
geetu040 Apr 24, 2025
44b1aae
update expected_slice
geetu040 Apr 24, 2025
1cb563f
add generation test for integration
geetu040 Apr 24, 2025
818bf04
fix dtype in generation test
geetu040 Apr 24, 2025
70caedc
Merge branch 'main' into minimax-text-01
geetu040 Apr 24, 2025
9c5397d
update authors
geetu040 Apr 24, 2025
edb9337
Merge branch 'main' into minimax-text-01
geetu040 May 26, 2025
8e418ec
update with changes in main
geetu040 May 26, 2025
9754ae8
update graident checkpointing and minor fixes
geetu040 May 27, 2025
3c1ff61
fix mutable attn_type_list
geetu040 May 27, 2025
c8f7ed2
rename: attn_type -> layer_type
geetu040 May 27, 2025
a107b2e
Merge branch "main" (resolve conflicts)
geetu040 May 29, 2025
51f5dd1
Merge branch 'main' into minimax-text-01
geetu040 Jun 3, 2025
7837c75
update for layer_types
geetu040 Jun 3, 2025
278aad7
update integration tests
geetu040 Jun 3, 2025
7ee260d
update checkpoint
geetu040 Jun 3, 2025
8e156d4
clean overview in docs
geetu040 Jun 3, 2025
9a94d04
Merge branch 'main' into minimax-text-01
geetu040 Jun 3, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -555,6 +555,8 @@
title: MegatronBERT
- local: model_doc/megatron_gpt2
title: MegatronGPT2
- local: model_doc/minimax
title: MiniMax
- local: model_doc/mistral
title: Mistral
- local: model_doc/mixtral
197 changes: 197 additions & 0 deletions docs/source/en/model_doc/minimax.md
@@ -0,0 +1,197 @@
<!--Copyright 2025 MiniMaxAI and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# MiniMax

## Overview

The MiniMax model was proposed in [MiniMax-01: Scaling Foundation Models with Lightning Attention](https://arxiv.org/abs/2501.08313) by MiniMax, Aonian Li, Bangwei Gong, Bo Yang, Boji Shan, Chang Liu, Cheng Zhu, Chunhao Zhang, Congchao Guo, Da Chen, Dong Li, Enwei Jiao, Gengxin Li, Guojun Zhang, Haohai Sun, Houze Dong, Jiadai Zhu, Jiaqi Zhuang, Jiayuan Song, Jin Zhu, Jingtao Han, Jingyang Li, Junbin Xie, Junhao Xu, Junjie Yan, Kaishun Zhang, Kecheng Xiao, Kexi Kang, Le Han, Leyang Wang, Lianfei Yu, Liheng Feng, Lin Zheng, Linbo Chai, Long Xing, Meizhi Ju, Mingyuan Chi, Mozhi Zhang, Peikai Huang, Pengcheng Niu, Pengfei Li, Pengyu Zhao, Qi Yang, Qidi Xu, Qiexiang Wang, Qin Wang, Qiuhui Li, Ruitao Leng, Shengmin Shi, Shuqi Yu, Sichen Li, Songquan Zhu, Tao Huang, Tianrun Liang, Weigao Sun, Weixuan Sun, Weiyu Cheng, Wenkai Li, Xiangjun Song, Xiao Su, Xiaodong Han, Xinjie Zhang, Xinzhu Hou, Xu Min, Xun Zou, Xuyang Shen, Yan Gong, Yingjie Zhu, Yipeng Zhou, Yiran Zhong, Yongyi Hu, Yuanxiang Fan, Yue Yu, Yufeng Yang, Yuhao Li, Yunan Huang, Yunji Li, Yunpeng Huang, Yunzhi Xu, Yuxin Mao, Zehan Li, Zekang Li, Zewei Tao, Zewen Ying, Zhaoyang Cong, Zhen Qin, Zhenhua Fan, Zhihang Yu, Zhuo Jiang, Zijia Wu.

The abstract from the paper is the following:

*We introduce MiniMax-01 series, including MiniMax-Text-01 and MiniMax-VL-01, which are comparable to top-tier models while offering superior capabilities in processing longer contexts. The core lies in lightning attention and its efficient scaling. To maximize computational capacity, we integrate it with Mixture of Experts (MoE), creating a model with 32 experts and 456 billion total parameters, of which 45.9 billion are activated for each token. We develop an optimized parallel strategy and highly efficient computation-communication overlap techniques for MoE and lightning attention. This approach enables us to conduct efficient training and inference on models with hundreds of billions of parameters across contexts spanning millions of tokens. The context window of MiniMax-Text-01 can reach up to 1 million tokens during training and extrapolate to 4 million tokens during inference at an affordable cost. Our vision-language model, MiniMax-VL-01 is built through continued training with 512 billion vision-language tokens. Experiments on both standard and in-house benchmarks show that our models match the performance of state-of-the-art models like GPT-4o and Claude-3.5-Sonnet while offering 20-32 times longer context window.*

<!-- TODO: upload this image at https://huggingface.co/datasets/huggingface/documentation-images -->
<img src="https://raw.githubusercontent.com/MiniMax-AI/MiniMax-01/main/figures/TextBench.png"
alt="drawing" width="600"/>

<small> Text benchmark for MiniMax. Taken from the <a href="https://github.com/MiniMax-AI/MiniMax-01" target="_blank">official code</a>. </small>

This model was contributed by [Armaghan](https://github.com/geetu040) and [Shakib](https://github.com/Shakib-IO). The original code can be found [here](https://huggingface.co/MiniMaxAI/MiniMax-Text-01/tree/main).

### Architectural details

MiniMax is a powerful language model with 456 billion total parameters, of which 45.9 billion are activated per token. To better unlock its long-context capabilities, MiniMax adopts a hybrid architecture that combines Lightning Attention, Softmax Attention, and Mixture-of-Experts (MoE). By leveraging advanced parallel strategies and innovative compute-communication overlap methods such as Linear Attention Sequence Parallelism Plus (LASP+), varlen ring attention, and Expert Tensor Parallel (ETP), MiniMax extends its training context length to 1 million tokens and can handle a context of up to 4 million tokens during inference. MiniMax also demonstrates top-tier performance on various academic benchmarks.
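
Lightning attention is a linear-attention variant: instead of materializing the full softmax attention matrix, it maintains a running key-value state that is decayed and updated token by token, which is what keeps very long contexts tractable. The sketch below is only a conceptual, unoptimized illustration of that decayed recurrence, not the kernel added in this PR; the constant `decay` value and tensor shapes are illustrative assumptions.

```python
import torch


def decayed_linear_attention(q, k, v, decay=0.95):
    # q, k, v: (batch, heads, seq_len, head_dim)
    # kv_state accumulates the outer product k^T v with exponential decay,
    # replacing the softmax over all previous positions.
    batch, heads, seq_len, dim = q.shape
    kv_state = torch.zeros(batch, heads, dim, dim, dtype=q.dtype, device=q.device)
    outputs = []
    for t in range(seq_len):
        kv_state = decay * kv_state + k[:, :, t, :, None] * v[:, :, t, None, :]
        outputs.append(torch.einsum("bhd,bhde->bhe", q[:, :, t], kv_state))
    return torch.stack(outputs, dim=2)  # (batch, heads, seq_len, head_dim)
```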

The architecture of MiniMax is briefly described as follows:

- Total Parameters: 456B
- Activated Parameters per Token: 45.9B
- Number of Layers: 80
- Hybrid Attention: a softmax attention layer is positioned after every 7 lightning attention layers.
- Number of attention heads: 64
- Attention head dimension: 128
- Mixture of Experts:
  - Number of experts: 32
  - Expert hidden dimension: 9216
  - Top-2 routing strategy
- Positional Encoding: Rotary Position Embedding (RoPE) applied to half of the attention head dimension with a base frequency of 10,000,000
- Hidden Size: 6144
- Vocab Size: 200,064

For more details refer to the [release blog post](https://www.minimaxi.com/en/news/minimax-01-series-2).
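
One quick way to see how these hyperparameters surface in Transformers is to load the model configuration and inspect it. The attribute names below follow the Mixtral-style configuration that `MiniMaxConfig` inherits from, so treat the exact names (e.g. `layer_types`, `num_local_experts`) and the checkpoint id as assumptions to verify against the generated `configuration_minimax.py`.

```python
from transformers import MiniMaxConfig

# Assumed attribute names, following the Mixtral-style config this model inherits from
config = MiniMaxConfig.from_pretrained("MiniMaxAI/MiniMax-Text-01")
print(config.num_hidden_layers)    # expected: 80
print(config.hidden_size)          # expected: 6144
print(config.num_attention_heads)  # expected: 64
print(config.num_local_experts)    # expected: 32
print(config.rope_theta)           # expected: 10000000
print(config.layer_types[:8])      # hybrid pattern of linear vs. full attention layers
```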

### License

`MiniMax` is released under the MINIMAX MODEL LICENSE AGREEMENT.

## Usage tips

The pre-trained model can be used as follows:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01")

>>> messages = [
... {"role": "user", "content": "What is your favourite condiment?"},
... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
... {"role": "user", "content": "Do you have mayonnaise recipes?"}
... ]

>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)

>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"Mayonnaise can be made as follows: (...)"
```

As can be seen, the instruction-tuned model requires a [chat template](../chat_templating) to be applied to make sure the inputs are prepared in the right format.

## Speeding up MiniMax by using Flash Attention

The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.

First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). Also make sure to load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention-2, refer to the snippet below:

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01")

>>> prompt = "My favourite condiment is"

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
```

### Sliding window Attention

The current implementation supports the sliding window attention mechanism and memory-efficient cache management.
To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`).

The Flash Attention-2 model also uses a more memory-efficient cache slicing mechanism. Following the rolling cache recommended by the official Mistral implementation, the cache size is kept fixed (`self.config.sliding_window`), batched generation is supported only for `padding_side="left"`, and the absolute position of the current token is used to compute the positional embedding.
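
As a rough mental model of the rolling cache described above, the sketch below keeps only the most recent `sliding_window` key/value positions whenever the cache is updated. It is a simplified illustration under that assumption, not the actual cache implementation in Transformers.

```python
import torch


def update_rolling_kv(cache_k, cache_v, new_k, new_v, sliding_window):
    # cache_* / new_*: (batch, num_heads, seq_len, head_dim)
    # Append the new keys/values, then keep only the last `sliding_window` positions.
    cache_k = torch.cat([cache_k, new_k], dim=-2)[:, :, -sliding_window:, :]
    cache_v = torch.cat([cache_v, new_v], dim=-2)[:, :, -sliding_window:, :]
    return cache_k, cache_v  # cache length never exceeds sliding_window
```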

## Shrinking down MiniMax using quantization

As the MiniMax model has 456 billion parameters, that would require about 912GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter), about 228 GB of RAM is required.
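
The memory estimate above is simple back-of-the-envelope arithmetic (parameter count times bytes per parameter). The small helper below just reproduces that calculation; it ignores activation and KV-cache memory.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    # bytes = parameters * bits / 8; reported in GB (1 GB = 1e9 bytes, matching the rough figures above)
    return num_params * bits_per_param / 8 / 1e9


print(weight_memory_gb(456e9, 16))  # ~912 GB in float16
print(weight_memory_gb(456e9, 4))   # ~228 GB in 4-bit
```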

Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization.md) for alternative quantization methods):

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

>>> # specify how to quantize the model
>>> quantization_config = BitsAndBytesConfig(
... load_in_4bit=True,
... bnb_4bit_quant_type="nf4",
... bnb_4bit_compute_dtype=torch.float16,
... )

>>> model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01", quantization_config=quantization_config, device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01")

>>> messages = [
... {"role": "user", "content": "What is your favourite condiment?"},
... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
... {"role": "user", "content": "Do you have mayonnaise recipes?"}
... ]

>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
```

This model was contributed by [geetu040](https://github.com/geetu040).
The original code can be found [here](https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax.py).

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with MiniMax. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

<PipelineTag pipeline="text-generation"/>

- The [Alignment Handbook](https://github.com/huggingface/alignment-handbook) by Hugging Face includes scripts and recipes to perform supervised fine-tuning (SFT) and direct preference optimization with Mistral-7B. This includes scripts for full fine-tuning, QLoRA on a single GPU as well as multi-GPU fine-tuning.
- [Causal language modeling task guide](../tasks/language_modeling)

## MiniMaxConfig

[[autodoc]] MiniMaxConfig

## MiniMaxModel

[[autodoc]] MiniMaxModel
- forward

## MiniMaxForCausalLM

[[autodoc]] MiniMaxForCausalLM
- forward

## MiniMaxForSequenceClassification

[[autodoc]] MiniMaxForSequenceClassification
- forward

## MiniMaxForTokenClassification

[[autodoc]] MiniMaxForTokenClassification
- forward

## MiniMaxForQuestionAnswering

[[autodoc]] MiniMaxForQuestionAnswering
- forward
1 change: 1 addition & 0 deletions src/transformers/generation/utils.py
@@ -1976,6 +1976,7 @@ def _supports_default_dynamic_cache(self) -> bool:
and "jamba" not in self.__class__.__name__.lower()
and "zamba" not in self.__class__.__name__.lower()
and "bamba" not in self.__class__.__name__.lower()
and "minimax" not in self.__class__.__name__.lower()
)

def _prepare_cache_for_generation(
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -184,6 +184,7 @@
from .megatron_gpt2 import *
from .mgp_str import *
from .mimi import *
from .minimax import *
from .mistral import *
from .mistral3 import *
from .mixtral import *
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -210,6 +210,7 @@
("megatron-bert", "MegatronBertConfig"),
("mgp-str", "MgpstrConfig"),
("mimi", "MimiConfig"),
("minimax", "MiniMaxConfig"),
("mistral", "MistralConfig"),
("mistral3", "Mistral3Config"),
("mixtral", "MixtralConfig"),
@@ -584,6 +585,7 @@
("megatron_gpt2", "Megatron-GPT2"),
("mgp-str", "MGP-STR"),
("mimi", "Mimi"),
("minimax", "MiniMax"),
("mistral", "Mistral"),
("mistral3", "Mistral3"),
("mixtral", "Mixtral"),
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -201,6 +201,7 @@
("megatron-bert", "MegatronBertModel"),
("mgp-str", "MgpstrForSceneTextRecognition"),
("mimi", "MimiModel"),
("minimax", "MiniMaxModel"),
("mistral", "MistralModel"),
("mistral3", "Mistral3Model"),
("mixtral", "MixtralModel"),
@@ -593,6 +594,7 @@
("mbart", "MBartForCausalLM"),
("mega", "MegaForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
("minimax", "MiniMaxForCausalLM"),
("mistral", "MistralForCausalLM"),
("mixtral", "MixtralForCausalLM"),
("mllama", "MllamaForCausalLM"),
@@ -1105,6 +1107,7 @@
("mbart", "MBartForSequenceClassification"),
("mega", "MegaForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("minimax", "MiniMaxForSequenceClassification"),
("mistral", "MistralForSequenceClassification"),
("mixtral", "MixtralForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
@@ -1196,6 +1199,7 @@
("mbart", "MBartForQuestionAnswering"),
("mega", "MegaForQuestionAnswering"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
("minimax", "MiniMaxForQuestionAnswering"),
("mistral", "MistralForQuestionAnswering"),
("mixtral", "MixtralForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
@@ -1302,6 +1306,7 @@
("markuplm", "MarkupLMForTokenClassification"),
("mega", "MegaForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("minimax", "MiniMaxForTokenClassification"),
("mistral", "MistralForTokenClassification"),
("mixtral", "MixtralForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -341,6 +341,13 @@
("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("mgp-str", ("MgpstrTokenizer", None)),
(
"minimax",
(
"GPT2Tokenizer" if is_sentencepiece_available() else None,
"GPT2TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mistral",
(
29 changes: 29 additions & 0 deletions src/transformers/models/minimax/__init__.py
@@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_minimax import *
from .modeling_minimax import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)