diff --git a/llama.cpp b/llama.cpp index c236277d437cc..cf3ee494af465 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1116,13 +1116,13 @@ struct llama_layer { struct ggml_tensor * ffn_norm_b; // ff - struct ggml_tensor * w1; // ffn_gate - struct ggml_tensor * w2; // ffn_down - struct ggml_tensor * w3; // ffn_up + struct ggml_tensor * ffn_gate; // w1 + struct ggml_tensor * ffn_down; // w2 + struct ggml_tensor * ffn_up; // w3 // ff bias - struct ggml_tensor * b2; // ffn_down - struct ggml_tensor * b3; // ffn_up + struct ggml_tensor * ffn_down_b; // b2 + struct ggml_tensor * ffn_up_b; // b3 }; struct llama_kv_cell { @@ -2538,15 +2538,15 @@ static void llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); } } } break; @@ -2604,15 +2604,15 @@ static void llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); } } } break; @@ -2683,14 +2683,14 @@ static void llm_load_tensors( layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wo) + - ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); } } } break; @@ -2756,11 +2756,11 @@ static void llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); - layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -2768,8 +2768,8 @@ static void llm_load_tensors( ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + - ggml_nbytes(layer.w2) + ggml_nbytes(layer.b2) + - ggml_nbytes(layer.w3) + ggml_nbytes(layer.b3); + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down_b) + + ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up_b); } } } break; @@ -2816,22 +2816,22 @@ static void llm_load_tensors( const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); - layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); - layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); - layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); layer.attn_q_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {64}, backend); - layer.attn_q_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}, backend); + layer.attn_q_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}, backend); layer.attn_k_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {64}, backend); - layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend); + layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend); } } break; case LLM_ARCH_BLOOM: @@ -2899,11 +2899,11 @@ static void llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); - layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -2911,8 +2911,8 @@ static void llm_load_tensors( ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + - ggml_nbytes(layer.w3) + ggml_nbytes(layer.b3) + - ggml_nbytes(layer.w2) + ggml_nbytes(layer.b2); + ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up_b) + + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down_b); } } } break; @@ -2969,8 +2969,8 @@ static void llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -2978,8 +2978,8 @@ static void llm_load_tensors( ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w2) + - ggml_nbytes(layer.w3); + ggml_nbytes(layer.ffn_down) + + ggml_nbytes(layer.ffn_up); } } } break; @@ -3129,6 +3129,107 @@ static struct ggml_tensor * llm_build_norm( return cur; } +enum llm_ffn_op_type { + LLM_FFN_SILU, + LLM_FFN_GELU, + LLM_FFN_RELU, + LLM_FFN_RELU_SQR, +}; + +enum llm_ffn_gate_type { + LLM_FFN_SEQ, + LLM_FFN_PAR, // ffn_gate is parallel to ffn_up +}; + +static struct ggml_tensor * llm_build_ffn( + struct ggml_context * ctx, + struct ggml_tensor * cur, + struct ggml_tensor * up, + struct ggml_tensor * up_b, + struct ggml_tensor * gate, + struct ggml_tensor * gate_b, + struct ggml_tensor * down, + struct ggml_tensor * down_b, + llm_ffn_op_type type_op, + llm_ffn_gate_type type_gate, + const llm_build_cb & cb, + int il) { + struct ggml_tensor * tmp = ggml_mul_mat(ctx, up, cur); + cb(tmp, "ffn_up", il); + + if (up_b) { + tmp = ggml_add(ctx, tmp, up_b); + cb(tmp, "ffn_up_b", il); + } + + if (gate) { + switch (type_gate) { + case LLM_FFN_SEQ: + { + cur = ggml_mul_mat(ctx, gate, tmp); + cb(cur, "ffn_gate", il); + + if (gate_b) { + cur = ggml_add(ctx, cur, gate_b); + cb(cur, "ffn_gate_b", il); + } + } break; + case LLM_FFN_PAR: + { + cur = ggml_mul_mat(ctx, gate, cur); + cb(cur, "ffn_gate", il); + + if (gate_b) { + cur = ggml_add(ctx, cur, gate_b); + cb(cur, "ffn_gate_b", il); + } + } break; + }; + } + + switch (type_op) { + case LLM_FFN_SILU: + { + cur = ggml_silu(ctx, cur); + cb(cur, "ffn_silu", il); + } break; + case LLM_FFN_GELU: + { + cur = ggml_gelu(ctx, cur); + cb(cur, "ffn_gelu", il); + } break; + case LLM_FFN_RELU: + { + cur = ggml_relu(ctx, cur); + cb(cur, "ffn_relu", il); + } break; + case LLM_FFN_RELU_SQR: + { + cur = ggml_relu(ctx, cur); + cb(cur, "ffn_relu", il); + + cur = ggml_sqr(ctx, cur); + cb(cur, "ffn_sqr(relu)", il); + } break; + }; + + if (type_gate == LLM_FFN_PAR) { + cur = ggml_mul(ctx, cur, tmp); + cb(cur, "ffn_gate_par", il); + } + + cur = ggml_mul_mat(ctx, down, cur); + if (down_b) { + cb(cur, "ffn_down", il); + } + + if (down_b) { + cur = ggml_add(ctx, cur, down_b); + } + + return cur; +} + static struct ggml_cgraph * llm_build_llama( llama_context & lctx, const llama_batch & batch, @@ -3346,27 +3447,12 @@ static struct ggml_cgraph * llm_build_llama( LLM_NORM_RMS, norm_rms_eps, cb, il); cb(cur, "ffn_norm", il); - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); - cb(tmp, "result_w3", il); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); - cb(cur, "result_w1", il); - - // SILU activation - cur = ggml_silu(ctx0, cur); - cb(cur, "silu", il); - - cur = ggml_mul(ctx0, cur, tmp); - cb(cur, "silu_x_result_w3", il); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); - cb(cur, "result_w2", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_result", il); } cur = ggml_add(ctx0, cur, inpFF); @@ -3627,27 +3713,12 @@ static struct ggml_cgraph * llm_build_baichaun( LLM_NORM_RMS, norm_rms_eps, cb, il); cb(cur, "ffn_norm", il); - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); - cb(tmp, "result_w3", il); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); - cb(cur, "result_w1", il); - - // SILU activation - cur = ggml_silu(ctx0, cur); - cb(cur, "silu", il); - - cur = ggml_mul(ctx0, cur, tmp); - cb(cur, "silu_x_result_w3", il); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); - cb(cur, "result_w2", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_result", il); } cur = ggml_add(ctx0, cur, inpFF); @@ -3911,16 +3982,12 @@ static struct ggml_cgraph * llm_build_falcon( // feed forward { - struct ggml_tensor * inpFF = attn_norm; - - cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); - cb(cur, "result_w3", il); - - cur = ggml_gelu(ctx0, cur); - cb(cur, "gelu", il); - - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); - cb(cur, "result_w2", il); + cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result + model.layers[il].ffn_up, NULL, + NULL, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_result", il); } cur = ggml_add(ctx0, cur, attn_out); @@ -4136,19 +4203,12 @@ static struct ggml_cgraph * llm_build_starcoder( LLM_NORM, norm_eps, cb, il); cb(cur, "ffn_norm", il); - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3); - cb(cur, "result_w3", il); - - // GELU activation - cur = ggml_gelu(ctx0, cur); - cb(cur, "gelu", il); - - // Projection - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); - cb(cur, "result_w2", il); - - cur = ggml_add(ctx0, cur, model.layers[il].b2); - cb(cur, "result_w2_b", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_result", il); } inpL = ggml_add(ctx0, cur, inpFF); @@ -4455,31 +4515,20 @@ static struct ggml_cgraph * llm_build_persimmon( struct ggml_tensor * inpFF = ggml_add(ctx0, residual, cur); cb(inpFF, "inpFF", il); + // feed-forward network { - // MLP cur = llm_build_norm(ctx0, inpFF, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, norm_eps, cb, il); cb(cur, "ffn_norm", il); - cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); - cb(cur, "result_w3", il); - - cur = ggml_add(ctx0, cur, model.layers[il].b3); - cb(cur, "result_w3_b", il); - - cur = ggml_relu(ctx0, cur); - cb(cur, "relu", il); - - cur = ggml_sqr(ctx0, cur); - cb(cur, "sqr(relu)", il); - - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); - cb(cur, "result_w2", il); - - cur = ggml_add(ctx0, cur, model.layers[il].b2); - cb(cur, "result_w2_b", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_result", il); } cur = ggml_add(ctx0, cur, inpFF); @@ -4687,27 +4736,12 @@ static struct ggml_cgraph * llm_build_refact( LLM_NORM_RMS, norm_rms_eps, cb, il); cb(cur, "ffn_norm", il); - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); - cb(tmp, "result_w3", il); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); - cb(cur, "result_w1", il); - - // SILU activation - cur = ggml_silu(ctx0, cur); - cb(cur, "silu", il); - - cur = ggml_mul(ctx0, cur, tmp); - cb(cur, "silu_x_result_w3", il); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); - cb(cur, "result_w2", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_result", il); } cur = ggml_add(ctx0, cur, inpFF); @@ -4932,20 +4966,12 @@ static struct ggml_cgraph * llm_build_bloom( LLM_NORM, norm_eps, cb, il); cb(cur, "ffn_norm", il); - cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); - cb(cur, "result_w3", il); - - cur = ggml_add(ctx0, cur, model.layers[il].b3); - cb(cur, "result_w3_b", il); - - cur = ggml_gelu(ctx0, cur); - cb(cur, "gelu", il); - - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); - cb(cur, "result_w2", il); - - cur = ggml_add(ctx0, cur, model.layers[il].b2); - cb(cur, "result_w2_b", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_result", il); } inpL = ggml_add(ctx0, cur, inpFF); @@ -5163,14 +5189,12 @@ static struct ggml_cgraph * llm_build_mpt( LLM_NORM, norm_eps, cb, il); cb(cur, "ffn_norm", il); - cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); - cb(cur, "result_w3", il); - - cur = ggml_gelu(ctx0, cur); - cb(cur, "gelu", il); - - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); - cb(cur, "result_w2", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + NULL, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_result", il); } cur = ggml_add(ctx0, cur, attn_out);