@@ -461,88 +461,6 @@ status_t fold_mul_scales(std::shared_ptr<subgraph_t> &sg) {
461461 return status::success;
462462}
463463
464- status_t fold_sum_scales (std::shared_ptr<subgraph_t > &sg) {
465- std::set<op_t *> visited;
466- subgraph_rewriter_t rewriter (sg);
467-
468- for (auto &cur_op : sg->get_ops ()) {
469- if (!(cur_op->get_kind () == op_kind::dnnl_binary
470- && static_cast <dnnl::algorithm>(
471- cur_op->get_attr <int64_t >(op_attr::alg_kind))
472- == dnnl::algorithm::binary_add)
473- || visited.count (cur_op.get ()))
474- continue ;
475-
476- visited.insert (cur_op.get ());
477- size_t mul_scale_op_offset = 2 ;
478- auto lhs_val = cur_op->get_input_value (0 );
479- auto rhs_val = cur_op->get_input_value (1 );
480-
481- if (!lhs_val->has_producer () || !rhs_val->has_producer ()) { continue ; }
482- const auto &l_op = lhs_val->get_producer ();
483- const auto &r_op = rhs_val->get_producer ();
484-
485- auto consumers = cur_op->get_output_values ()[0 ]->get_consumers ();
486- if (consumers.empty ()
487- || consumers[0 ].get_op ().get_kind ()
488- != op_kind::dnnl_mul_scales) {
489- continue ;
490- }
491-
492- if (l_op.get_kind () != op_kind::dnnl_mul_scales
493- || r_op.get_kind () != op_kind::dnnl_mul_scales) {
494- continue ;
495- }
496- if (l_op.num_inputs () > 0 && l_op.get_input_value (0 )->has_producer ()
497- && l_op.get_input_value (0 )->get_producer ().get_kind ()
498- == op_kind::dnnl_reorder) {
499- mul_scale_op_offset = 1 ;
500- } else if (r_op.num_inputs () > 0
501- && r_op.get_input_value (0 )->has_producer ()
502- && r_op.get_input_value (0 )->get_producer ().get_kind ()
503- == op_kind::dnnl_reorder) {
504- mul_scale_op_offset = 0 ;
505- }
506-
507- if (mul_scale_op_offset != 2
508- && ltw (lhs_val->get_logical_tensor ()).vdims ()
509- == ltw (rhs_val->get_logical_tensor ()).vdims ()) {
510- auto in_val = cur_op->get_input_value (mul_scale_op_offset);
511- auto &mul_scale_op = in_val->get_producer ();
512- auto scales = mul_scale_op.get_attr <std::vector<float >>(
513- op_attr::scales);
514- assert (scales.size () == 1 ); // per tensor
515-
516- auto tmp = mul_scale_op.get_input_value (0 );
517- auto &add_zps_op = tmp->get_producer ();
518- auto zps = add_zps_op.get_attr <std::vector<int64_t >>(op_attr::zps);
519- assert (scales.size () == zps.size ());
520-
521- auto out_val = cur_op->get_output_values ()[0 ];
522- auto consumers = out_val->get_consumers ();
523- auto &next_op = consumers[0 ].get_op ();
524- // set sum post-ops' second input scale
525- float tmp_scale
526- = next_op.get_attr <std::vector<float >>(op_attr::scales)[0 ];
527- scales[0 ] *= tmp_scale;
528- mul_scale_op.set_attr <std::vector<float >>(op_attr::scales, scales);
529-
530- // update the output scales
531- auto other_val = cur_op->get_input_value (1 - mul_scale_op_offset);
532- auto &oscales_op = other_val->get_producer ();
533- auto oscales
534- = oscales_op.get_attr <std::vector<float >>(op_attr::scales);
535- for (auto &v : oscales)
536- v *= tmp_scale;
537- oscales_op.set_attr <std::vector<float >>(op_attr::scales, oscales);
538- rewriter.fuse_op_to_predecessor (next_op.shared_from_this ());
539- }
540- }
541-
542- rewriter.run ();
543- return status::success;
544- }
545-
546464// FIXME(xx) This pass works correctly only when all inputs/outputs scales/zps
547465// are same, since we are simply ignoring the scales and zps. We can improve
548466// this pass to support different per-tensor scale since oneDNN concat primitive
0 commit comments