Skip to content

Commit ff680fc

Browse files
yifeizh2 authored and vpirogov committed
graph: backend: compiler: update llama mlp
1 parent d6c216a commit ff680fc

File tree

1 file changed

+46
-14
lines changed

1 file changed

+46
-14
lines changed

src/graph/backend/graph_compiler/patterns/mlp_pattern.hpp

Lines changed: 46 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,7 @@ void create_gpt_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
257257

258258
pm::pb_node_t *append_rms_norm_option1(
259259
const std::shared_ptr<pb_graph_t> &pgraph, pm::pb_node_t *input,
260-
bool is_bf16 = false, bool is_int8 = false) {
260+
bool is_bf16 = false, bool is_int8 = false, bool end_cast = false) {
261261
if (is_bf16) {
262262
auto typecast = pgraph->append_op(
263263
graph::op_kind::TypeCast, {in_edge(0, input, 0)});
@@ -277,14 +277,18 @@ pm::pb_node_t *append_rms_norm_option1(
277277
auto mul2 = pgraph->append_op(
278278
graph::op_kind::Multiply, {in_edge(0, cast1, 0)});
279279
mul2->allow_external_outputs();
280-
UNUSED(is_bf16);
280+
pm::pb_node_t *output = mul2;
281+
if (end_cast) {
282+
output = pgraph->append_op(
283+
graph::op_kind::TypeCast, {in_edge(0, mul2, 0)});
284+
}
281285
UNUSED(is_int8);
282-
return mul2;
286+
return output;
283287
};
284288

285289
pm::pb_node_t *append_rms_norm_option2(
286290
const std::shared_ptr<pb_graph_t> &pgraph, pm::pb_node_t *input,
287-
bool is_bf16 = false, bool is_int8 = false) {
291+
bool is_bf16 = false, bool is_int8 = false, bool end_cast = false) {
288292
pm::pb_node_t *pow_in = input;
289293
pm::pb_node_t *mul1_in = input;
290294
if (is_bf16) {
@@ -311,9 +315,13 @@ pm::pb_node_t *append_rms_norm_option2(
311315
auto mul2 = pgraph->append_op(
312316
graph::op_kind::Multiply, {in_edge(0, cast1, 0)});
313317
mul2->allow_external_outputs();
314-
UNUSED(is_bf16);
318+
pm::pb_node_t *output = mul2;
319+
if (end_cast) {
320+
output = pgraph->append_op(
321+
graph::op_kind::TypeCast, {in_edge(0, mul2, 0)});
322+
}
315323
UNUSED(is_int8);
316-
return mul2;
324+
return output;
317325
};
318326

319327
/*
@@ -343,20 +351,20 @@ pm::pb_node_t *append_rms_norm_option2(
343351
*/
344352
void create_llama_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
345353
bool is_bf16 = false, bool is_int8 = false,
346-
bool use_rms_norm_alternative = false,
347-
bool split_smooth_quant = false) {
354+
bool use_rms_norm_alternative = false, bool split_smooth_quant = false,
355+
bool end_cast = false) {
348356
auto matmul1 = create_dequant_matmul(pgraph, nullptr, is_bf16, is_int8);
349357
auto add1
350358
= pgraph->append_op(graph::op_kind::Add, {in_edge(0, matmul1, 0)});
351359
add1->allow_external_outputs();
352360
auto norm1 = use_rms_norm_alternative
353-
? append_rms_norm_option1(pgraph, add1, is_bf16, is_int8)
354-
: append_rms_norm_option2(pgraph, add1, is_bf16, is_int8);
361+
? append_rms_norm_option1(pgraph, add1, is_bf16, is_int8, end_cast)
362+
: append_rms_norm_option2(pgraph, add1, is_bf16, is_int8, end_cast);
355363

356364
pm::pb_node_t *norm1_for_lhs = norm1, *norm1_for_rhs = norm1;
357365
if (is_int8) {
358366
auto extra_cast_before_mul = append_single_op_repetition_subgraph(
359-
pgraph, graph::op_kind::TypeCast, norm1);
367+
pgraph, graph::op_kind::TypeCast, norm1_for_lhs);
360368
auto smooth_quant_mul1 = append_single_op_repetition_subgraph(
361369
pgraph, graph::op_kind::Multiply, extra_cast_before_mul);
362370
auto extra_cast_after_mul = append_single_op_repetition_subgraph(
@@ -366,7 +374,7 @@ void create_llama_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
366374
if (split_smooth_quant) {
367375
auto extra_cast_before_mul_rhs
368376
= append_single_op_repetition_subgraph(
369-
pgraph, graph::op_kind::TypeCast, norm1);
377+
pgraph, graph::op_kind::TypeCast, norm1_for_rhs);
370378
auto smooth_quant_mul1_rhs = append_single_op_repetition_subgraph(
371379
pgraph, graph::op_kind::Multiply,
372380
extra_cast_before_mul_rhs);
@@ -414,8 +422,8 @@ void create_llama_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
414422
graph::op_kind::Add, {in_edge(0, matmul4, 0), in_edge(1, add1, 0)});
415423
add2->allow_external_outputs();
416424
auto norm2 = use_rms_norm_alternative
417-
? append_rms_norm_option1(pgraph, add2, is_bf16, is_int8)
418-
: append_rms_norm_option2(pgraph, add2, is_bf16, is_int8);
425+
? append_rms_norm_option1(pgraph, add2, is_bf16, is_int8, end_cast)
426+
: append_rms_norm_option2(pgraph, add2, is_bf16, is_int8, end_cast);
419427
if (is_int8) {
420428
auto extra_cast_before_mul = append_single_op_repetition_subgraph(
421429
pgraph, graph::op_kind::TypeCast, norm2);
@@ -813,6 +821,22 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(compiler, int8_bf16_llama_mlp)
813821
.set_attr<FCreatePattern>("FCreatePattern",
814822
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
815823
create_llama_mlp(pgraph, true, true, true, true);
824+
})
825+
.set_attr<FCreatePattern>("FCreatePattern",
826+
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
827+
create_llama_mlp(pgraph, true, true, false, false, true);
828+
})
829+
.set_attr<FCreatePattern>("FCreatePattern",
830+
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
831+
create_llama_mlp(pgraph, true, true, true, false, true);
832+
})
833+
.set_attr<FCreatePattern>("FCreatePattern",
834+
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
835+
create_llama_mlp(pgraph, true, true, false, true, true);
836+
})
837+
.set_attr<FCreatePattern>("FCreatePattern",
838+
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
839+
create_llama_mlp(pgraph, true, true, true, true, true);
816840
});
817841
COMPILER_BACKEND_REGISTER_PASSES_DEF_END
818842

@@ -1077,6 +1101,14 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(compiler, bf16_llama_mlp)
10771101
.set_attr<FCreatePattern>("FCreatePattern",
10781102
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
10791103
create_llama_mlp(pgraph, true, false, true);
1104+
})
1105+
.set_attr<FCreatePattern>("FCreatePattern",
1106+
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
1107+
create_llama_mlp(pgraph, true, false, false, false, true);
1108+
})
1109+
.set_attr<FCreatePattern>("FCreatePattern",
1110+
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
1111+
create_llama_mlp(pgraph, true, false, true, false, true);
10801112
});
10811113
COMPILER_BACKEND_REGISTER_PASSES_DEF_END
10821114

0 commit comments

Comments
 (0)