
Commit 699ba75

xinyu-intel authored and TaoLv committed
graph: backend: remove fold_sum_scales pass
1 parent b8d21a5 · commit 699ba75

File tree

1 file changed: +0 −82 lines changed


src/graph/backend/dnnl/passes/transform.cpp

Lines changed: 0 additions & 82 deletions
@@ -461,88 +461,6 @@ status_t fold_mul_scales(std::shared_ptr<subgraph_t> &sg) {
     return status::success;
 }
 
-status_t fold_sum_scales(std::shared_ptr<subgraph_t> &sg) {
-    std::set<op_t *> visited;
-    subgraph_rewriter_t rewriter(sg);
-
-    for (auto &cur_op : sg->get_ops()) {
-        if (!(cur_op->get_kind() == op_kind::dnnl_binary
-                    && static_cast<dnnl::algorithm>(
-                               cur_op->get_attr<int64_t>(op_attr::alg_kind))
-                            == dnnl::algorithm::binary_add)
-                || visited.count(cur_op.get()))
-            continue;
-
-        visited.insert(cur_op.get());
-        size_t mul_scale_op_offset = 2;
-        auto lhs_val = cur_op->get_input_value(0);
-        auto rhs_val = cur_op->get_input_value(1);
-
-        if (!lhs_val->has_producer() || !rhs_val->has_producer()) { continue; }
-        const auto &l_op = lhs_val->get_producer();
-        const auto &r_op = rhs_val->get_producer();
-
-        auto consumers = cur_op->get_output_values()[0]->get_consumers();
-        if (consumers.empty()
-                || consumers[0].get_op().get_kind()
-                        != op_kind::dnnl_mul_scales) {
-            continue;
-        }
-
-        if (l_op.get_kind() != op_kind::dnnl_mul_scales
-                || r_op.get_kind() != op_kind::dnnl_mul_scales) {
-            continue;
-        }
-        if (l_op.num_inputs() > 0 && l_op.get_input_value(0)->has_producer()
-                && l_op.get_input_value(0)->get_producer().get_kind()
-                        == op_kind::dnnl_reorder) {
-            mul_scale_op_offset = 1;
-        } else if (r_op.num_inputs() > 0
-                && r_op.get_input_value(0)->has_producer()
-                && r_op.get_input_value(0)->get_producer().get_kind()
-                        == op_kind::dnnl_reorder) {
-            mul_scale_op_offset = 0;
-        }
-
-        if (mul_scale_op_offset != 2
-                && ltw(lhs_val->get_logical_tensor()).vdims()
-                        == ltw(rhs_val->get_logical_tensor()).vdims()) {
-            auto in_val = cur_op->get_input_value(mul_scale_op_offset);
-            auto &mul_scale_op = in_val->get_producer();
-            auto scales = mul_scale_op.get_attr<std::vector<float>>(
-                    op_attr::scales);
-            assert(scales.size() == 1); // per tensor
-
-            auto tmp = mul_scale_op.get_input_value(0);
-            auto &add_zps_op = tmp->get_producer();
-            auto zps = add_zps_op.get_attr<std::vector<int64_t>>(op_attr::zps);
-            assert(scales.size() == zps.size());
-
-            auto out_val = cur_op->get_output_values()[0];
-            auto consumers = out_val->get_consumers();
-            auto &next_op = consumers[0].get_op();
-            // set sum post-ops' second input scale
-            float tmp_scale
-                    = next_op.get_attr<std::vector<float>>(op_attr::scales)[0];
-            scales[0] *= tmp_scale;
-            mul_scale_op.set_attr<std::vector<float>>(op_attr::scales, scales);
-
-            // update the output scales
-            auto other_val = cur_op->get_input_value(1 - mul_scale_op_offset);
-            auto &oscales_op = other_val->get_producer();
-            auto oscales
-                    = oscales_op.get_attr<std::vector<float>>(op_attr::scales);
-            for (auto &v : oscales)
-                v *= tmp_scale;
-            oscales_op.set_attr<std::vector<float>>(op_attr::scales, oscales);
-            rewriter.fuse_op_to_predecessor(next_op.shared_from_this());
-        }
-    }
-
-    rewriter.run();
-    return status::success;
-}
-
 // FIXME(xx) This pass works correctly only when all inputs/outputs scales/zps
 // are same, since we are simply ignoring the scales and zps. We can improve
 // this pass to support different per-tensor scale since oneDNN concat primitive
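For context, the removed pass implemented a standard scale-folding rewrite: when a dnnl_mul_scales op consumed the output of a binary_add whose inputs were themselves produced by dnnl_mul_scales ops, the trailing output scale was multiplied into both input scales (the `scales[0] *= tmp_scale` and `v *= tmp_scale` lines above) and the trailing op was fused into its predecessor. The rewrite is valid because per-tensor scaling distributes over addition. The following standalone C++ sketch, using hypothetical helpers apply_scale and add_elems that are not part of oneDNN, checks that identity numerically:

#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper: per-tensor scaling of a dequantized buffer.
static std::vector<float> apply_scale(const std::vector<float> &x, float s) {
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i)
        out[i] = x[i] * s;
    return out;
}

// Hypothetical helper: elementwise addition (the binary_add step).
static std::vector<float> add_elems(
        const std::vector<float> &a, const std::vector<float> &b) {
    assert(a.size() == b.size());
    std::vector<float> out(a.size());
    for (size_t i = 0; i < a.size(); ++i)
        out[i] = a[i] + b[i];
    return out;
}

int main() {
    const std::vector<float> x {1.f, 2.f, 3.f};
    const std::vector<float> y {4.f, 5.f, 6.f};
    const float s_l = 0.5f, s_r = 0.25f, s_out = 2.f;

    // Before folding: mul_scales -> binary_add -> mul_scales.
    auto before = apply_scale(
            add_elems(apply_scale(x, s_l), apply_scale(y, s_r)), s_out);

    // After folding: s_out is absorbed into both input scales and the
    // trailing mul_scales op disappears, mirroring the removed pass.
    auto after = add_elems(
            apply_scale(x, s_l * s_out), apply_scale(y, s_r * s_out));

    // s_out * (s_l*x + s_r*y) == (s_out*s_l)*x + (s_out*s_r)*y
    for (size_t i = 0; i < before.size(); ++i)
        assert(std::fabs(before[i] - after[i]) < 1e-6f);
    return 0;
}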
