
Commit ad590be

huydt84, huydt-bti, and ggerganov authored

model : add NeoBERT (#14164)

* convert neobert model to gguf
* add inference graph
* fix flake8 lint
* followed reviewer suggestions

  Co-authored-by: Georgi Gerganov <[email protected]>

* follow reviewers suggestions

  Co-authored-by: Georgi Gerganov <[email protected]>

* override NeoBERT feed-forward length

---------

Co-authored-by: dinhhuy <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>

1 parent 7d6d91b · commit ad590be

File tree

6 files changed: +222 -1 lines changed

convert_hf_to_gguf.py

Lines changed: 29 additions & 1 deletion

@@ -519,7 +519,7 @@ def prepare_metadata(self, vocab_only: bool):
     def set_gguf_parameters(self):
         self.gguf_writer.add_block_count(self.block_count)

-        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions"], optional=True)) is not None:
+        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)) is not None:
             self.gguf_writer.add_context_length(n_ctx)
             logger.info(f"gguf: context length = {n_ctx}")

@@ -4076,6 +4076,34 @@ def _is_tokenizer_xlmroberta(self) -> bool:
         raise ValueError(f"unknown tokenizer: {toktyp}")


+@ModelBase.register("NeoBERT", "NeoBERTLMHead", "NeoBERTForSequenceClassification")
+class NeoBert(BertModel):
+    model_arch = gguf.MODEL_ARCH.NEO_BERT
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # NeoBERT uses 2/3 of the intermediate size as feed forward length
+        self.gguf_writer.add_feed_forward_length(int(2 * self.hparams["intermediate_size"] / 3))
+        self.gguf_writer.add_rope_freq_base(10000.0)  # default value for NeoBERT
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+
+        f_rms_eps = self.hparams.get("norm_eps", 1e-6)  # default value for NeoBERT
+        self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
+        logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
+
+        self.gguf_writer.add_pooling_type(gguf.PoolingType.CLS)  # https://huggingface.co/chandar-lab/NeoBERT#how-to-use
+
+    def modify_tensors(self, data_torch, name, bid):
+        if name.startswith("decoder."):
+            return []
+
+        if name.startswith("model."):
+            name = name[6:]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
 class XLMRobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
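
As a quick illustration of the hyperparameter mapping the new NeoBert converter performs, the Python sketch below mirrors the logic above on a stand-in config dict; the intermediate_size and max_length values here are illustrative assumptions, not taken from the actual chandar-lab/NeoBERT config.

    # Stand-in hparams dict (values are assumptions, for illustration only)
    hparams = {"intermediate_size": 3072, "norm_eps": 1e-6, "max_length": 4096}

    # NeoBERT uses 2/3 of the intermediate size as the feed-forward length
    n_ff = int(2 * hparams["intermediate_size"] / 3)   # -> 2048

    # RMS norm epsilon falls back to the NeoBERT default when absent
    f_rms_eps = hparams.get("norm_eps", 1e-6)

    # "max_length" is now one of the keys find_hparam() checks for the context length
    n_ctx = hparams.get("max_length")

    print(n_ff, f_rms_eps, n_ctx)   # 2048 1e-06 4096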

gguf-py/gguf/constants.py

Lines changed: 14 additions & 0 deletions

@@ -291,6 +291,7 @@ class MODEL_ARCH(IntEnum):
     BERT = auto()
     NOMIC_BERT = auto()
     NOMIC_BERT_MOE = auto()
+    NEO_BERT = auto()
     JINA_BERT_V2 = auto()
     BLOOM = auto()
     STABLELM = auto()

@@ -573,6 +574,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.BERT:           "bert",
     MODEL_ARCH.NOMIC_BERT:     "nomic-bert",
     MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
+    MODEL_ARCH.NEO_BERT:       "neo-bert",
     MODEL_ARCH.JINA_BERT_V2:   "jina-bert-v2",
     MODEL_ARCH.BLOOM:          "bloom",
     MODEL_ARCH.STABLELM:       "stablelm",

@@ -1081,6 +1083,18 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.NEO_BERT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ENC_OUTPUT_NORM,
+        MODEL_TENSOR.CLS,
+        MODEL_TENSOR.CLS_OUT,
+    ],
     MODEL_ARCH.JINA_BERT_V2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
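
The new enum value and tensor list are visible through the gguf-py package. A minimal sketch (assuming a gguf-py checkout containing this change is importable) that prints the GGUF name templates registered for NEO_BERT:

    import gguf

    arch = gguf.MODEL_ARCH.NEO_BERT
    print(gguf.MODEL_ARCH_NAMES[arch])      # "neo-bert"
    for tensor in gguf.MODEL_TENSORS[arch]:
        # name templates such as "blk.{bid}.attn_qkv" or "enc.output_norm"
        print(gguf.TENSOR_NAMES[tensor])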

gguf-py/gguf/tensor_mapping.py

Lines changed: 9 additions & 0 deletions

@@ -31,6 +31,7 @@ class TensorNameMap:
         "model.embeddings",                  # rwkv7
         "model.word_embeddings",             # bailingmoe
         "language_model.model.embed_tokens", # llama4
+        "encoder",                           # neobert
     ),

     # Token type embeddings

@@ -134,6 +135,7 @@ class TensorNameMap:
         "rwkv.blocks.{bid}.ln1",                    # rwkv6
         "model.layers.{bid}.ln1",                   # rwkv7
         "model.layers.{bid}.input_layernorm",       # llama4
+        "transformer_encoder.{bid}.attention_norm", # neobert
     ),

     # Attention norm 2

@@ -161,6 +163,7 @@ class TensorNameMap:
         "model.layers.{bid}.self_attn.qkv_proj",               # phi3
         "encoder.layers.{bid}.self_attention.query_key_value", # chatglm
         "transformer.layers.{bid}.attn.qkv_proj",              # openelm
+        "transformer_encoder.{bid}.qkv",                       # neobert
     ),

     # Attention query

@@ -236,6 +239,7 @@ class TensorNameMap:
         "transformer.layers.{bid}.attn.out_proj",      # openelm
         "transformer.h.{bid}.attn.attention.out_proj", # exaone
         "model.layers.{bid}.self_attn.o_proj",         # llama4
+        "transformer_encoder.{bid}.wo",                # neobert
     ),

     # Attention output norm

@@ -276,6 +280,7 @@ class TensorNameMap:
         "encoder.layers.{bid}.post_attention_layernorm", # chatglm
         "transformer.layers.{bid}.ffn_norm",             # openelm
         "model.layers.{bid}.post_attention_layernorm",   # llama4
+        "transformer_encoder.{bid}.ffn_norm",            # neobert
     ),

     # Post feed-forward norm

@@ -340,6 +345,7 @@ class TensorNameMap:
         "encoder.layers.{bid}.mlp.dense_h_to_4h",  # chatglm
         "transformer.h.{bid}.mlp.c_fc_1",          # exaone
         "model.layers.{bid}.feed_forward.up_proj", # llama4
+        "transformer_encoder.{bid}.ffn.w12",       # neobert
     ),

     MODEL_TENSOR.FFN_UP_EXP: (

@@ -422,6 +428,7 @@ class TensorNameMap:
         "encoder.layers.{bid}.mlp.dense_4h_to_h",    # chatglm
         "model.layers.h.{bid}.mlp.c_proj",           # exaone
         "model.layers.{bid}.feed_forward.down_proj", # llama4
+        "transformer_encoder.{bid}.ffn.w3",          # neobert
     ),

     MODEL_TENSOR.FFN_DOWN_EXP: (

@@ -832,12 +839,14 @@ class TensorNameMap:
     # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
     MODEL_TENSOR.ENC_OUTPUT_NORM: (
         "encoder.final_layer_norm", # t5
+        "layer_norm",               # neobert
     ),

     MODEL_TENSOR.CLS: (
         "classifier",       # jina
         "classifier.dense", # roberta
         "pre_classifier",   # distillbert
+        "dense",            # neobert
     ),

     MODEL_TENSOR.CLS_OUT: (
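
With these mappings in place, TensorNameMap can translate NeoBERT checkpoint tensor names (after the converter strips the "model." prefix) into GGUF names. A small sketch of the expected behaviour, again assuming the gguf-py checkout from this commit; the block count of 28 is the NeoBERT layer count used elsewhere in this change.

    import gguf

    name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.NEO_BERT, 28)

    # expected to resolve to "blk.0.attn_qkv.weight" given the mapping added above
    print(name_map.get_name("transformer_encoder.0.qkv.weight", try_suffixes=(".weight", ".bias")))

    # expected to resolve to "enc.output_norm.weight"
    print(name_map.get_name("layer_norm.weight", try_suffixes=(".weight", ".bias")))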

src/llama-arch.cpp

Lines changed: 16 additions & 0 deletions

@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BERT,           "bert"           },
     { LLM_ARCH_NOMIC_BERT,     "nomic-bert"     },
     { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
+    { LLM_ARCH_NEO_BERT,       "neo-bert"       },
     { LLM_ARCH_JINA_BERT_V2,   "jina-bert-v2"   },
     { LLM_ARCH_BLOOM,          "bloom"          },
     { LLM_ARCH_STABLELM,       "stablelm"       },

@@ -514,6 +515,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_NEO_BERT,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
+            { LLM_TENSOR_CLS,             "cls" },
+            { LLM_TENSOR_CLS_OUT,         "cls.output" },
+        },
+    },
     {
         LLM_ARCH_JINA_BERT_V2,
         {

src/llama-arch.h

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ enum llm_arch {
     LLM_ARCH_BERT,
     LLM_ARCH_NOMIC_BERT,
     LLM_ARCH_NOMIC_BERT_MOE,
+    LLM_ARCH_NEO_BERT,
     LLM_ARCH_JINA_BERT_V2,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,

src/llama-model.cpp

Lines changed: 153 additions & 0 deletions

@@ -749,6 +749,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
             }
         } break;
+        case LLM_ARCH_NEO_BERT:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,            hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE,                hparams.pooling_type);
+
+                if (hparams.n_layer == 28) {
+                    type = LLM_TYPE_250M;
+                }
+            } break;
        case LLM_ARCH_BLOOM:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

@@ -2212,6 +2222,32 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                    layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
                }
            } break;
+        case LLM_ARCH_NEO_BERT:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
+                cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         TENSOR_NOT_REQUIRED);
+
+                cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+                cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {hparams.n_cls_out},         TENSOR_NOT_REQUIRED);
+
+                output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                    layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff*2}, 0);
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                }
+            } break;
        case LLM_ARCH_JINA_BERT_V2:
            {
                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings

@@ -6182,6 +6218,117 @@ struct llm_build_bert : public llm_graph_context {
    }
};

+struct llm_build_neo_bert : public llm_graph_context {
+    llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // construct input embeddings (token, type, position)
+        inpL = build_inp_embd(model.tok_embd);
+        cb(inpL, "inp_embd", -1);
+
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        // iterate layers
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * cur = inpL;
+
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
+            // pre-norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+
+            // self-attention
+            cur = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(cur, "wqkv", il);
+
+            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // RoPE
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
+
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // re-add the layer input
+            cur = ggml_add(ctx0, cur, inpL);
+
+            ggml_tensor * ffn_inp = cur;
+            cb(ffn_inp, "ffn_inp", il);
+
+            // pre-norm
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,
+                    NULL, NULL, NULL, NULL, NULL,
+                    model.layers[il].ffn_down,
+                    NULL, NULL, NULL,
+                    LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+
+            // attentions bypass the intermediate layer
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm_enc, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_embd", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_bloom : public llm_graph_context {
     llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;

@@ -13595,6 +13742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
        case LLM_ARCH_JINA_BERT_V2:
        case LLM_ARCH_NOMIC_BERT:
        case LLM_ARCH_NOMIC_BERT_MOE:
+       case LLM_ARCH_NEO_BERT:
        case LLM_ARCH_WAVTOKENIZER_DEC:
            {
                res = nullptr;

@@ -13703,6 +13851,10 @@ llm_graph_result_ptr llama_model::build_graph(
            {
                llm = std::make_unique<llm_build_bert>(*this, params, gf);
            } break;
+       case LLM_ARCH_NEO_BERT:
+           {
+               llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
+           } break;
        case LLM_ARCH_BLOOM:
            {
                llm = std::make_unique<llm_build_bloom>(*this, params, gf);

@@ -14082,6 +14234,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_GRANITE_MOE:
        case LLM_ARCH_CHAMELEON:
        case LLM_ARCH_BAILINGMOE:
+       case LLM_ARCH_NEO_BERT:
        case LLM_ARCH_ARCEE:
            return LLAMA_ROPE_TYPE_NORM;
