
Commit 71de20b

Add Arcee model support (#38621)
* Add Arcee model support to transformers
  - Add ArceeConfig and model mappings for all task types (CausalLM, SequenceClassification, QuestionAnswering, TokenClassification)
  - Add auto-loading support through AutoModel, AutoConfig, and AutoTokenizer
  - Use LlamaTokenizer for tokenization
  - Add FX graph support for Arcee models
  - Create lazy loading module structure for Arcee

* feat: update YARN scaling and RoPE validation for Arcee model

* feat: add auto_docstring checkpoint config to Arcee model classes

* docs: add pre-trained model weights reference to Arcee configuration files

* refactor: move RoPE utilities to dedicated modeling_rope_utils module

* Add comprehensive test suite for Arcee model
  - Add test_modeling_arcee.py following standard transformers test patterns
  - Include tests for all model variants (CausalLM, SequenceClassification, QuestionAnswering, TokenClassification)
  - Add specific test for ReLU² activation in ArceeMLP
  - Add RoPE scaling tests including YARN support
  - Follow CausalLMModelTest pattern used by similar models

* Add documentation for Arcee model
  - Add comprehensive model documentation with usage examples
  - Include all model variants in autodoc
  - Add to table of contents in proper alphabetical order
  - Fixes documentation coverage for Arcee model classes

* Make style/fixup

* fix copyright year

* Sync modular conversion

* revert in legacy supported models in src/transformers/utils/fx

* cleaned redundant code in modular_arcee.py

* cleaned testing

* removed pretraining tp

* fix styles

* integration testing

---------

Co-authored-by: Pranav <[email protected]>
Co-authored-by: Pranav <[email protected]>
1 parent 23c89a6 commit 71de20b

12 files changed: +1605 -0 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -363,6 +363,8 @@
  - sections:
    - local: model_doc/albert
      title: ALBERT
+   - local: model_doc/arcee
+     title: Arcee
    - local: model_doc/bamba
      title: Bamba
    - local: model_doc/bart

docs/source/en/model_doc/arcee.md

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

# Arcee

Arcee is a decoder-only transformer model based on the Llama architecture with a key modification: it uses ReLU² (ReLU-squared) activation in the MLP blocks instead of SiLU, following recent research showing improved training efficiency with squared activations. This architecture is designed for efficient training and inference while maintaining the proven stability of the Llama design.

The Arcee model is architecturally similar to Llama but uses `x * relu(x)` in MLP layers for improved gradient flow and is optimized for efficiency in both training and inference scenarios.
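To make the activation concrete, the following is a minimal sketch of a ReLU²-activated feed-forward block in PyTorch. It is illustrative only: the class and layer names are hypothetical and are not taken from `modeling_arcee.py`, which defines the actual `ArceeMLP`.

```py
import torch
import torch.nn as nn


def relu2(x: torch.Tensor) -> torch.Tensor:
    # ReLU²: equivalent to x * relu(x), i.e. relu(x) squared.
    return torch.square(torch.relu(x))


class Relu2FeedForward(nn.Module):
    """Illustrative ReLU²-activated MLP block (not the actual ArceeMLP)."""

    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project up, apply ReLU², project back down to the hidden size.
        return self.down_proj(relu2(self.up_proj(x)))


hidden_states = torch.randn(1, 8, 2560)
print(Relu2FeedForward(2560, 18432)(hidden_states).shape)  # torch.Size([1, 8, 2560])
```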
> [!TIP]
> The Arcee model supports extended context with RoPE scaling and all standard transformers features including Flash Attention 2, SDPA, gradient checkpointing, and quantization support.

The example below demonstrates how to generate text with Arcee using [`Pipeline`] or [`AutoModel`].

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
import torch
from transformers import pipeline

pipeline = pipeline(
    task="text-generation",
    model="arcee-ai/AFM-4.5B",
    torch_dtype=torch.float16,
    device=0
)

output = pipeline("The key innovation in Arcee is")
print(output[0]["generated_text"])
```

</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import AutoTokenizer, ArceeForCausalLM

tokenizer = AutoTokenizer.from_pretrained("arcee-ai/AFM-4.5B")
model = ArceeForCausalLM.from_pretrained(
    "arcee-ai/AFM-4.5B",
    torch_dtype=torch.float16,
    device_map="auto"
)

inputs = tokenizer("The key innovation in Arcee is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

</hfoption>
</hfoptions>
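Because the RoPE parameters can be overridden at load time, extended-context inference with YARN scaling might be configured as below. This is a sketch: the `factor` and `original_max_position_embeddings` values are illustrative assumptions, not settings published for the `arcee-ai/AFM-4.5B` checkpoint.

```py
import torch
from transformers import ArceeForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("arcee-ai/AFM-4.5B")
model = ArceeForCausalLM.from_pretrained(
    "arcee-ai/AFM-4.5B",
    torch_dtype=torch.float16,
    device_map="auto",
    # Hypothetical YARN override: scale an assumed 4096-token pre-trained context by 4x.
    rope_scaling={"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 4096},
)

inputs = tokenizer("Summarize the following very long document ...", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```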
## ArceeConfig

[[autodoc]] ArceeConfig

## ArceeModel

[[autodoc]] ArceeModel
    - forward

## ArceeForCausalLM

[[autodoc]] ArceeForCausalLM
    - forward

## ArceeForSequenceClassification

[[autodoc]] ArceeForSequenceClassification
    - forward

## ArceeForQuestionAnswering

[[autodoc]] ArceeForQuestionAnswering
    - forward

## ArceeForTokenClassification

[[autodoc]] ArceeForTokenClassification
    - forward

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 from .albert import *
 from .align import *
 from .altclip import *
+from .arcee import *
 from .aria import *
 from .audio_spectrogram_transformer import *
 from .auto import *
src/transformers/models/arcee/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_arcee import *
    from .modeling_arcee import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
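The lazy `__init__` above defers importing the Arcee submodules until they are first accessed, while the commit message notes that the model is also wired into the Auto classes. A quick smoke test of that auto-loading path might look like the following; it assumes access to the `arcee-ai/AFM-4.5B` checkpoint referenced in the documentation above.

```py
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# The "arcee" model_type should resolve to the Arcee classes through the
# auto mappings added in this commit.
config = AutoConfig.from_pretrained("arcee-ai/AFM-4.5B")
print(type(config).__name__, config.model_type)  # ArceeConfig arcee

tokenizer = AutoTokenizer.from_pretrained("arcee-ai/AFM-4.5B")
model = AutoModelForCausalLM.from_pretrained("arcee-ai/AFM-4.5B")  # downloads the full weights
print(type(model).__name__)  # ArceeForCausalLM
```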
src/transformers/models/arcee/configuration_arcee.py

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/arcee/modular_arcee.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_arcee.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation


class ArceeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ArceeModel`]. It is used to instantiate an Arcee
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of AFM-4.5B-Base.

    Pre-trained weights are available at [arcee-ai/AFM-4.5B](https://huggingface.co/arcee-ai/AFM-4.5B) and were used to
    build the examples below.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Arcee model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`ArceeModel`].
        hidden_size (`int`, *optional*, defaults to 2560):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 18432):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with. AFM-4.5B-Base supports up to 16384 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 128000):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 128001):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original
                    RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'yarn'. The original max position embeddings used during pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn'. The scaling factor to be applied on the attention computation. If unspecified,
                    it defaults to the value recommended by the implementation, using the `factor` field to infer the
                    suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*):
            The attention head dimension. If None, it will default to hidden_size // num_attention_heads.

    ```python
    >>> from transformers import ArceeModel, ArceeConfig

    >>> # Initializing an Arcee AFM-4.5B-Base style configuration
    >>> configuration = ArceeConfig()

    >>> # Initializing a model from the AFM-4.5B-Base style configuration
    >>> model = ArceeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "arcee"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=2560,
        intermediate_size=18432,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="relu2",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=128000,
        eos_token_id=128001,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)


__all__ = ["ArceeConfig"]
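As a rough illustration of the defaulting and validation logic above, the sketch below builds a small configuration by hand. The values are made up for the example and do not describe a released checkpoint; the legacy `"type"` key is used deliberately to exercise the backward-compatibility branch before `rope_config_validation` runs.

```py
from transformers import ArceeConfig

config = ArceeConfig(
    hidden_size=2560,
    num_attention_heads=32,
    num_key_value_heads=8,  # GQA: every 4 query heads share one key/value head
    max_position_embeddings=16384,
    # Legacy "type" key; __init__ copies it to "rope_type" before validation.
    rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 4096},
)

print(config.head_dim)                   # 80 == hidden_size // num_attention_heads
print(config.rope_scaling["rope_type"])  # "yarn"
```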
