
Commit 4356d96

fairydreaming authored and sszymczy committed
Add missing inference support for GPTNeoXForCausalLM (Pythia and GPT-NeoX base models) (ggml-org#7461)
* convert-hf : add conversion of bloom-style qkv tensor to gpt-style qkv (code borrowed from BloomModel)
* llama : add inference support for LLM_ARCH_GPTNEOX
* llama : add model types for every Pythia variant and GPT-NeoX

Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent 9f473c5 commit 4356d96

File tree: 2 files changed, 273 additions and 1 deletion

  convert-hf-to-gguf.py
  llama.cpp

convert-hf-to-gguf.py

Lines changed: 38 additions & 0 deletions
@@ -622,6 +622,44 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
 
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
+        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+
+        tensors: list[tuple[str, Tensor]] = []
+
+        if re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.weight", name):
+            # Map bloom-style qkv_linear to gpt-style qkv_linear
+            # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
+            # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
+            qkv_weights = data_torch.reshape((n_head, 3, n_embed // n_head, n_embed))
+            data_torch = torch.cat(
+                (
+                    qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
+                    qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
+                    qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
+                ),
+                dim=0,
+            )
+            logger.info("re-format attention.linear_qkv.weight")
+        elif re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.bias", name):
+            qkv_bias = data_torch.reshape((n_head, 3, n_embed // n_head))
+            data_torch = torch.cat(
+                (
+                    qkv_bias[:, 0, :].reshape((n_embed,)),
+                    qkv_bias[:, 1, :].reshape((n_embed,)),
+                    qkv_bias[:, 2, :].reshape((n_embed,)),
+                ),
+                dim=0,
+            )
+            logger.info("re-format attention.linear_qkv.bias")
+
+        tensors.append((self.map_tensor_name(name), data_torch))
+
+        return tensors
+
 
 @Model.register("BloomForCausalLM")
 class BloomModel(Model):
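For intuition, here is a minimal standalone sketch (my own illustration, not part of the commit) of the layout change modify_tensors performs: GPT-NeoX checkpoints store the fused QKV weight grouped per head as [q_h, k_h, v_h] blocks (bloom-style), while llama.cpp expects all query rows, then all key rows, then all value rows (gpt-style). Toy sizes are used so the sanity check is easy to follow.

import torch

n_head, n_embd = 4, 16                 # toy sizes, far smaller than any real model
head_dim = n_embd // n_head

# bloom-style fused QKV: rows grouped per head as [q_h, k_h, v_h]
qkv_bloom = torch.arange(3 * n_embd * n_embd, dtype=torch.float32).reshape(3 * n_embd, n_embd)

# the same re-ordering the new GPTNeoXModel.modify_tensors applies
w = qkv_bloom.reshape(n_head, 3, head_dim, n_embd)
qkv_gpt = torch.cat(
    (
        w[:, 0].reshape(-1, n_embd),   # all query rows, heads 0..n_head-1
        w[:, 1].reshape(-1, n_embd),   # all key rows
        w[:, 2].reshape(-1, n_embd),   # all value rows
    ),
    dim=0,
)

# sanity check: head 1's first key row moves from bloom row (1*3 + 1)*head_dim
# to gpt row n_embd + 1*head_dim
assert torch.equal(qkv_bloom[(1 * 3 + 1) * head_dim], qkv_gpt[n_embd + 1 * head_dim])
print(qkv_gpt.shape)                   # torch.Size([48, 16])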

llama.cpp

Lines changed: 235 additions & 1 deletion
@@ -1696,17 +1696,24 @@ static llama_state g_state;
 // available llama models
 enum e_model {
     MODEL_UNKNOWN,
+    MODEL_14M,
     MODEL_17M,
     MODEL_22M,
     MODEL_33M,
+    MODEL_70M,
     MODEL_109M,
     MODEL_137M,
+    MODEL_160M,
     MODEL_335M,
+    MODEL_410M,
     MODEL_0_5B,
     MODEL_1B,
+    MODEL_1_4B,
     MODEL_2B,
+    MODEL_2_8B,
     MODEL_3B,
     MODEL_4B,
+    MODEL_6_9B,
     MODEL_7B,
     MODEL_8B,
     MODEL_12B,
@@ -1738,6 +1745,7 @@ static const size_t GiB = 1024*MiB;
 struct llama_hparams {
     bool vocab_only;
     bool rope_finetuned;
+    bool use_par_res;
 
     uint32_t n_vocab;
     uint32_t n_ctx_train; // context size the model was trained on
@@ -3777,17 +3785,24 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
 
 static const char * llama_model_type_name(e_model type) {
     switch (type) {
+        case MODEL_14M:   return "14M";
         case MODEL_17M:   return "17M";
         case MODEL_22M:   return "22M";
         case MODEL_33M:   return "33M";
+        case MODEL_70M:   return "70M";
         case MODEL_109M:  return "109M";
         case MODEL_137M:  return "137M";
+        case MODEL_160M:  return "160M";
         case MODEL_335M:  return "335M";
+        case MODEL_410M:  return "410M";
         case MODEL_0_5B:  return "0.5B";
         case MODEL_1B:    return "1B";
+        case MODEL_1_4B:  return "1.4B";
         case MODEL_2B:    return "2B";
+        case MODEL_2_8B:  return "2.8B";
         case MODEL_3B:    return "3B";
         case MODEL_4B:    return "4B";
+        case MODEL_6_9B:  return "6.9B";
         case MODEL_7B:    return "7B";
         case MODEL_8B:    return "8B";
         case MODEL_12B:   return "12B";
@@ -4286,6 +4301,52 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GPTNEOX:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
+                switch (hparams.n_layer) {
+                    case 6:
+                        switch (hparams.n_ff) {
+                            case 512: model.type = e_model::MODEL_14M; break;
+                            case 2048: model.type = e_model::MODEL_70M; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 12:
+                        switch (hparams.n_ff) {
+                            case 3072: model.type = e_model::MODEL_160M; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 16:
+                        switch (hparams.n_ff) {
+                            case 8192: model.type = e_model::MODEL_1B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 24:
+                        switch (hparams.n_ff) {
+                            case 4096: model.type = e_model::MODEL_410M; break;
+                            case 8192: model.type = e_model::MODEL_1_4B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 32:
+                        switch (hparams.n_ff) {
+                            case 10240: model.type = e_model::MODEL_2_8B; break;
+                            case 16384: model.type = e_model::MODEL_6_9B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 36:
+                        switch (hparams.n_ff) {
+                            case 20480: model.type = e_model::MODEL_12B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 44:
+                        switch (hparams.n_ff) {
+                            case 24576: model.type = e_model::MODEL_20B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
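As a quick reference, the (n_layer, n_ff) pairs above correspond to the published Pythia checkpoints plus GPT-NeoX-20B. A hypothetical Python helper (not part of the commit) that mirrors the new switch might look like:

# Hypothetical helper mirroring the switch above: map (n_layer, n_ff) from a
# GPT-NeoX/Pythia config to the size label llama.cpp will report.
PYTHIA_SIZES: dict[tuple[int, int], str] = {
    (6, 512): "14M",     (6, 2048): "70M",
    (12, 3072): "160M",
    (16, 8192): "1B",
    (24, 4096): "410M",  (24, 8192): "1.4B",
    (32, 10240): "2.8B", (32, 16384): "6.9B",
    (36, 20480): "12B",
    (44, 24576): "20B",
}

def model_type(n_layer: int, n_ff: int) -> str:
    return PYTHIA_SIZES.get((n_layer, n_ff), "UNKNOWN")

print(model_type(32, 16384))  # 6.9B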

@@ -6037,6 +6098,41 @@ static bool llm_load_tensors(
                     layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
                 }
             } break;
+        case LLM_ARCH_GPTNEOX:
+            {
+                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                // output
+                {
+                    model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                    model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    ggml_context * ctx_layer = ctx_for_layer(i);
+                    ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                    auto & layer = model.layers[i];
+
+                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                    layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+
+                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                    layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+
+                    layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                    layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+
+                    layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                    layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                    layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                    layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+
+                    layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+                    layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),   {n_ff});
+                }
+            } break;
         default:
             throw std::runtime_error("unknown architecture");
     }
@@ -10564,6 +10660,140 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_gptneox() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // ffn
+            if (hparams.use_par_res) {
+                // attention and ffn are computed in parallel
+                // x = x + attn(ln1(x)) + ffn(ln2(x))
+
+                struct ggml_tensor * attn_out = cur;
+
+                cur = llm_build_norm(ctx0, inpL, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, inpL);
+                cb(cur, "ffn_out", il);
+
+                inpL = ggml_add(ctx0, cur, attn_out);
+                cb(inpL, "l_out", il);
+            } else {
+                // attention and ffn are computed sequentially
+                // x = x + attn(ln1(x))
+                // x = x + ffn(ln2(x))
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+                cb(ffn_inp, "ffn_inp", il);
+
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+
+                inpL = ggml_add(ctx0, cur, ffn_inp);
+                cb(inpL, "l_out", il);
+            }
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
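The use_par_res flag read in llm_load_hparams selects between the two residual wirings GPT-NeoX models were trained with, already summarized in the comments above. A rough framework-free sketch (my own illustration, not the ggml graph code):

def gptneox_block(x, attn, ffn, ln1, ln2, use_par_res: bool):
    """Schematic of the two residual layouts handled in build_gptneox."""
    if use_par_res:
        # parallel residual: x = x + attn(ln1(x)) + ffn(ln2(x))
        return x + attn(ln1(x)) + ffn(ln2(x))
    # sequential residual: x = x + attn(ln1(x)); x = x + ffn(ln2(x))
    x = x + attn(ln1(x))
    return x + ffn(ln2(x))

f = lambda t: t * 0.5  # stand-in sub-modules, just for demonstration
print(gptneox_block(1.0, attn=f, ffn=f, ln1=f, ln2=f, use_par_res=True))   # 1.5
print(gptneox_block(1.0, attn=f, ffn=f, ln1=f, ln2=f, use_par_res=False))  # 1.5625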
@@ -10774,6 +11004,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_GPTNEOX:
+            {
+                result = llm.build_gptneox();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -15766,7 +16000,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         // these models do not use RoPE
         case LLM_ARCH_GPT2:
        case LLM_ARCH_GPTJ:
-        case LLM_ARCH_GPTNEOX:
         case LLM_ARCH_MPT:
         case LLM_ARCH_REFACT:
         case LLM_ARCH_BLOOM:
@@ -15802,6 +16035,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHI3:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_GPTNEOX:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here
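GPT-NeoX previously sat in the "no RoPE" group, which is part of what broke inference; it now reports LLAMA_ROPE_TYPE_NEOX. Loosely, the NeoX rope style pairs rotary dimension i with dimension i + d/2 (the rotate_half convention of the Hugging Face GPT-NeoX code) rather than with its adjacent neighbour. A small illustrative sketch under that assumption, not taken from llama.cpp, applied here to all head dimensions for simplicity:

import math

def rope_neox(x: list[float], pos: int, freq_base: float = 10000.0) -> list[float]:
    """Apply NeoX-style rotary embedding to one head vector at one position.

    Dimension i is rotated together with dimension i + d/2 (half-split pairing),
    rather than with its adjacent neighbour as in the 'normal' rope style.
    """
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        theta = pos * freq_base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        out[i]          = x[i] * c - x[i + d // 2] * s
        out[i + d // 2] = x[i] * s + x[i + d // 2] * c
    return out

print(rope_neox([1.0, 0.0, 0.0, 0.0], pos=1))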
