@@ -257,7 +257,7 @@ void create_gpt_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
257257
258258pm::pb_node_t *append_rms_norm_option1 (
259259 const std::shared_ptr<pb_graph_t > &pgraph, pm::pb_node_t *input,
260- bool is_bf16 = false , bool is_int8 = false ) {
260+ bool is_bf16 = false , bool is_int8 = false , bool end_cast = false ) {
261261 if (is_bf16) {
262262 auto typecast = pgraph->append_op (
263263 graph::op_kind::TypeCast, {in_edge (0 , input, 0 )});
@@ -277,14 +277,18 @@ pm::pb_node_t *append_rms_norm_option1(
277277 auto mul2 = pgraph->append_op (
278278 graph::op_kind::Multiply, {in_edge (0 , cast1, 0 )});
279279 mul2->allow_external_outputs ();
280- UNUSED (is_bf16);
280+ pm::pb_node_t *output = mul2;
281+ if (end_cast) {
282+ output = pgraph->append_op (
283+ graph::op_kind::TypeCast, {in_edge (0 , mul2, 0 )});
284+ }
281285 UNUSED (is_int8);
282- return mul2 ;
286+ return output ;
283287};
284288
285289pm::pb_node_t *append_rms_norm_option2 (
286290 const std::shared_ptr<pb_graph_t > &pgraph, pm::pb_node_t *input,
287- bool is_bf16 = false , bool is_int8 = false ) {
291+ bool is_bf16 = false , bool is_int8 = false , bool end_cast = false ) {
288292 pm::pb_node_t *pow_in = input;
289293 pm::pb_node_t *mul1_in = input;
290294 if (is_bf16) {
@@ -311,9 +315,13 @@ pm::pb_node_t *append_rms_norm_option2(
311315 auto mul2 = pgraph->append_op (
312316 graph::op_kind::Multiply, {in_edge (0 , cast1, 0 )});
313317 mul2->allow_external_outputs ();
314- UNUSED (is_bf16);
318+ pm::pb_node_t *output = mul2;
319+ if (end_cast) {
320+ output = pgraph->append_op (
321+ graph::op_kind::TypeCast, {in_edge (0 , mul2, 0 )});
322+ }
315323 UNUSED (is_int8);
316- return mul2 ;
324+ return output ;
317325};
318326
319327/*
@@ -343,20 +351,20 @@ pm::pb_node_t *append_rms_norm_option2(
343351*/
344352void create_llama_mlp (const std::shared_ptr<pb_graph_t > &pgraph,
345353 bool is_bf16 = false , bool is_int8 = false ,
346- bool use_rms_norm_alternative = false ,
347- bool split_smooth_quant = false ) {
354+ bool use_rms_norm_alternative = false , bool split_smooth_quant = false ,
355+ bool end_cast = false ) {
348356 auto matmul1 = create_dequant_matmul (pgraph, nullptr , is_bf16, is_int8);
349357 auto add1
350358 = pgraph->append_op (graph::op_kind::Add, {in_edge (0 , matmul1, 0 )});
351359 add1->allow_external_outputs ();
352360 auto norm1 = use_rms_norm_alternative
353- ? append_rms_norm_option1 (pgraph, add1, is_bf16, is_int8)
354- : append_rms_norm_option2 (pgraph, add1, is_bf16, is_int8);
361+ ? append_rms_norm_option1 (pgraph, add1, is_bf16, is_int8, end_cast )
362+ : append_rms_norm_option2 (pgraph, add1, is_bf16, is_int8, end_cast );
355363
356364 pm::pb_node_t *norm1_for_lhs = norm1, *norm1_for_rhs = norm1;
357365 if (is_int8) {
358366 auto extra_cast_before_mul = append_single_op_repetition_subgraph (
359- pgraph, graph::op_kind::TypeCast, norm1 );
367+ pgraph, graph::op_kind::TypeCast, norm1_for_lhs );
360368 auto smooth_quant_mul1 = append_single_op_repetition_subgraph (
361369 pgraph, graph::op_kind::Multiply, extra_cast_before_mul);
362370 auto extra_cast_after_mul = append_single_op_repetition_subgraph (
@@ -366,7 +374,7 @@ void create_llama_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
366374 if (split_smooth_quant) {
367375 auto extra_cast_before_mul_rhs
368376 = append_single_op_repetition_subgraph (
369- pgraph, graph::op_kind::TypeCast, norm1 );
377+ pgraph, graph::op_kind::TypeCast, norm1_for_rhs );
370378 auto smooth_quant_mul1_rhs = append_single_op_repetition_subgraph (
371379 pgraph, graph::op_kind::Multiply,
372380 extra_cast_before_mul_rhs);
@@ -414,8 +422,8 @@ void create_llama_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
414422 graph::op_kind::Add, {in_edge (0 , matmul4, 0 ), in_edge (1 , add1, 0 )});
415423 add2->allow_external_outputs ();
416424 auto norm2 = use_rms_norm_alternative
417- ? append_rms_norm_option1 (pgraph, add2, is_bf16, is_int8)
418- : append_rms_norm_option2 (pgraph, add2, is_bf16, is_int8);
425+ ? append_rms_norm_option1 (pgraph, add2, is_bf16, is_int8, end_cast )
426+ : append_rms_norm_option2 (pgraph, add2, is_bf16, is_int8, end_cast );
419427 if (is_int8) {
420428 auto extra_cast_before_mul = append_single_op_repetition_subgraph (
421429 pgraph, graph::op_kind::TypeCast, norm2);
@@ -813,6 +821,22 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(compiler, int8_bf16_llama_mlp)
813821 .set_attr<FCreatePattern>(" FCreatePattern" ,
814822 [](const std::shared_ptr<pb_graph_t > &pgraph) -> void {
815823 create_llama_mlp (pgraph, true , true , true , true );
824+ })
825+ .set_attr<FCreatePattern>(" FCreatePattern" ,
826+ [](const std::shared_ptr<pb_graph_t > &pgraph) -> void {
827+ create_llama_mlp (pgraph, true , true , false , false , true );
828+ })
829+ .set_attr<FCreatePattern>(" FCreatePattern" ,
830+ [](const std::shared_ptr<pb_graph_t > &pgraph) -> void {
831+ create_llama_mlp (pgraph, true , true , true , false , true );
832+ })
833+ .set_attr<FCreatePattern>(" FCreatePattern" ,
834+ [](const std::shared_ptr<pb_graph_t > &pgraph) -> void {
835+ create_llama_mlp (pgraph, true , true , false , true , true );
836+ })
837+ .set_attr<FCreatePattern>(" FCreatePattern" ,
838+ [](const std::shared_ptr<pb_graph_t > &pgraph) -> void {
839+ create_llama_mlp (pgraph, true , true , true , true , true );
816840 });
817841COMPILER_BACKEND_REGISTER_PASSES_DEF_END
818842
@@ -1077,6 +1101,14 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(compiler, bf16_llama_mlp)
10771101 .set_attr<FCreatePattern>(" FCreatePattern" ,
10781102 [](const std::shared_ptr<pb_graph_t > &pgraph) -> void {
10791103 create_llama_mlp (pgraph, true , false , true );
1104+ })
1105+ .set_attr<FCreatePattern>(" FCreatePattern" ,
1106+ [](const std::shared_ptr<pb_graph_t > &pgraph) -> void {
1107+ create_llama_mlp (pgraph, true , false , false , false , true );
1108+ })
1109+ .set_attr<FCreatePattern>(" FCreatePattern" ,
1110+ [](const std::shared_ptr<pb_graph_t > &pgraph) -> void {
1111+ create_llama_mlp (pgraph, true , false , true , false , true );
10801112 });
10811113COMPILER_BACKEND_REGISTER_PASSES_DEF_END
10821114
0 commit comments