|
1 | 1 | /******************************************************************************* |
2 | | - * Copyright 2022 Intel Corporation |
| 2 | + * Copyright 2022-2023 Intel Corporation |
3 | 3 | * |
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 5 | * you may not use this file except in compliance with the License. |
@@ -146,20 +146,34 @@ dnnl::primitive_attr make_dnnl_primitive_attr( |
146 | 146 | // post-sum |
147 | 147 | float scale = pop->get_scale(); |
148 | 148 | int32_t zp = pop->get_zp(); |
149 | | - dnnl::memory::data_type sum_dt = dnnl::memory::data_type::undef; |
150 | | - if (op->get_kind() == op_kind::dnnl_convolution) { |
151 | | - const auto psrc_dt = op->get_input_value(extra_inputs[0]) |
152 | | - ->get_logical_tensor() |
153 | | - .data_type; |
154 | | - const auto dst_dt = op->get_output_value(0) |
155 | | - ->get_logical_tensor() |
156 | | - .data_type; |
| 149 | + const auto psrc_dt = op->get_input_value(extra_inputs[0]) |
| 150 | + ->get_logical_tensor() |
| 151 | + .data_type; |
| 152 | + const auto dst_dt = op->get_output_value(0) |
| 153 | + ->get_logical_tensor() |
| 154 | + .data_type; |
| 155 | + // note that onednn doesn't support float post-sum with u8/s8 |
| 156 | + // dst. use post-binary for such case instead. |
| 157 | + if (impl::utils::one_of( |
| 158 | + dst_dt, impl::data_type::u8, impl::data_type::s8) |
| 159 | + && impl::utils::one_of(psrc_dt, impl::data_type::f32, |
| 160 | + impl::data_type::bf16) |
| 161 | + && scale == 1.f && zp == 0) { |
| 162 | + auto input = op->get_input_value(extra_inputs[0]); |
| 163 | + auto md = make_dnnl_memory_desc( |
| 164 | + input->get_logical_tensor()); |
| 165 | + dnnl_pops.append_binary(dnnl::algorithm::binary_add, md); |
| 166 | + op->remove_attr(op_attr::with_sum); |
| 167 | + pop->to_post_binary(); |
| 168 | + } else { |
| 169 | + dnnl::memory::data_type sum_dt |
| 170 | + = dnnl::memory::data_type::undef; |
157 | 171 | if (psrc_dt == impl::data_type::s8 |
158 | 172 | && dst_dt == impl::data_type::u8) { |
159 | 173 | sum_dt = dnnl::memory::data_type::s8; |
160 | 174 | } |
| 175 | + dnnl_pops.append_sum(scale, zp, sum_dt); |
161 | 176 | } |
162 | | - dnnl_pops.append_sum(scale, zp, sum_dt); |
163 | 177 | } else { |
164 | 178 | // post-binary |
165 | 179 | assertm(extra_inputs.size() == 1, |
|
0 commit comments