Skip to content

Commit 9421fb2

Browse files
xinyu-intel authored and TaoLv committed
graph: backend: use post binary for float psrc and int dst
1 parent 9e2e266 commit 9421fb2

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

src/graph/backend/dnnl/fusion_info.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2022 Intel Corporation
2+
* Copyright 2022-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -146,20 +146,34 @@ dnnl::primitive_attr make_dnnl_primitive_attr(
146146
// post-sum
147147
float scale = pop->get_scale();
148148
int32_t zp = pop->get_zp();
149-
dnnl::memory::data_type sum_dt = dnnl::memory::data_type::undef;
150-
if (op->get_kind() == op_kind::dnnl_convolution) {
151-
const auto psrc_dt = op->get_input_value(extra_inputs[0])
152-
->get_logical_tensor()
153-
.data_type;
154-
const auto dst_dt = op->get_output_value(0)
155-
->get_logical_tensor()
156-
.data_type;
149+
const auto psrc_dt = op->get_input_value(extra_inputs[0])
150+
->get_logical_tensor()
151+
.data_type;
152+
const auto dst_dt = op->get_output_value(0)
153+
->get_logical_tensor()
154+
.data_type;
155+
// note that onednn doesn't support float post-sum with u8/s8
156+
// dst. use post-binary for such case instead.
157+
if (impl::utils::one_of(
158+
dst_dt, impl::data_type::u8, impl::data_type::s8)
159+
&& impl::utils::one_of(psrc_dt, impl::data_type::f32,
160+
impl::data_type::bf16)
161+
&& scale == 1.f && zp == 0) {
162+
auto input = op->get_input_value(extra_inputs[0]);
163+
auto md = make_dnnl_memory_desc(
164+
input->get_logical_tensor());
165+
dnnl_pops.append_binary(dnnl::algorithm::binary_add, md);
166+
op->remove_attr(op_attr::with_sum);
167+
pop->to_post_binary();
168+
} else {
169+
dnnl::memory::data_type sum_dt
170+
= dnnl::memory::data_type::undef;
157171
if (psrc_dt == impl::data_type::s8
158172
&& dst_dt == impl::data_type::u8) {
159173
sum_dt = dnnl::memory::data_type::s8;
160174
}
175+
dnnl_pops.append_sum(scale, zp, sum_dt);
161176
}
162-
dnnl_pops.append_sum(scale, zp, sum_dt);
163177
} else {
164178
// post-binary
165179
assertm(extra_inputs.size() == 1,

src/graph/backend/dnnl/fusion_info.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2022 Intel Corporation
2+
* Copyright 2022-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -84,6 +84,12 @@ class fusion_info_t {
8484
return op_->get_kind() == op_kind::dnnl_binary && !is_post_sum_;
8585
}
8686

87+
// Switches this fused post-op's bookkeeping away from post-sum by clearing
// is_post_sum_. Only scale_, zp_ and is_post_sum_ are touched here.
void to_post_binary() {
    // A post-binary cannot carry a scale or zero point, so the switch is
    // only valid while both still hold their neutral values.
    assertm(scale_ == 1.0f && zp_ == 0,
            "post binary cannot support scale and zp!");
    is_post_sum_ = false;
}
92+
8793
private:
8894
std::shared_ptr<op_t> op_;
8995
// used to represent post-eltwise and post-sum's scale

0 commit comments

Comments
 (0)