Skip to content

Commit 56baff7

Browse files
authored
[cadence][hifi] update quantized_relu_per_tensor_out signature to match internal flow
Differential Revision: D69015308 Pull Request resolved: #8143
1 parent baa5ec7 commit 56baff7

File tree

1 file changed

+41
-15
lines changed

1 file changed

+41
-15
lines changed

backends/cadence/hifi/operators/op_quantized_relu_out.cpp

+41-15
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,27 @@ void quantized_relu_(
4848
void quantized_relu_per_tensor_out(
4949
KernelRuntimeContext& ctx,
5050
const Tensor& input,
51-
const Tensor& in_zero_point,
51+
const int64_t in_zero_point,
5252
const int64_t out_zero_point,
53-
const Tensor& out_multiplier,
54-
const Tensor& out_shift,
53+
const int64_t out_multiplier,
54+
const int64_t out_shift,
5555
Tensor& output) {
56+
const uint8_t _in_zero_point = static_cast<uint8_t>(in_zero_point);
57+
const uint8_t _out_zero_point = static_cast<uint8_t>(out_zero_point);
58+
const int32_t _out_multiplier = static_cast<int32_t>(out_multiplier);
59+
const int32_t _out_shift = static_cast<int32_t>(out_shift);
5660
if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
5761
const uint8_t* p_in = input.const_data_ptr<uint8_t>();
5862
uint8_t* p_out = output.mutable_data_ptr<uint8_t>();
59-
uint8_t q_zero_point = in_zero_point.const_data_ptr<uint8_t>()[0];
6063

6164
WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u(
6265
p_out,
6366
p_in,
64-
(int)q_zero_point,
65-
out_multiplier.const_data_ptr<int32_t>()[0],
66-
out_shift.const_data_ptr<int32_t>()[0],
67-
(int)out_zero_point,
68-
(int)out_zero_point,
67+
_in_zero_point,
68+
_out_multiplier,
69+
_out_shift,
70+
_out_zero_point,
71+
_out_zero_point,
6972
255,
7073
input.numel());
7174

@@ -74,16 +77,15 @@ void quantized_relu_per_tensor_out(
7477
} else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
7578
const int8_t* p_in = input.const_data_ptr<int8_t>();
7679
int8_t* p_out = output.mutable_data_ptr<int8_t>();
77-
int8_t q_zero_point = in_zero_point.const_data_ptr<int8_t>()[0];
7880

7981
WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s(
8082
p_out,
8183
p_in,
82-
(int)q_zero_point,
83-
out_multiplier.const_data_ptr<int32_t>()[0],
84-
out_shift.const_data_ptr<int32_t>()[0],
85-
(int)out_zero_point,
86-
(int)out_zero_point,
84+
_in_zero_point,
85+
_out_multiplier,
86+
_out_shift,
87+
_out_zero_point,
88+
_out_zero_point,
8789
127,
8890
input.numel());
8991

@@ -97,6 +99,30 @@ void quantized_relu_per_tensor_out(
9799
}
98100
}
99101

102+
void quantized_relu_per_tensor_out(
103+
KernelRuntimeContext& ctx,
104+
const Tensor& input,
105+
const Tensor& in_zero_point,
106+
const int64_t out_zero_point,
107+
const Tensor& out_multiplier,
108+
const Tensor& out_shift,
109+
Tensor& output) {
110+
const uint8_t* p_in = input.const_data_ptr<uint8_t>();
111+
uint8_t* p_out = output.mutable_data_ptr<uint8_t>();
112+
uint8_t _in_zero_point = in_zero_point.const_data_ptr<uint8_t>()[0];
113+
int32_t _out_multiplier = out_multiplier.const_data_ptr<int32_t>()[0];
114+
int32_t _out_shift = out_shift.const_data_ptr<int32_t>()[0];
115+
116+
quantized_relu_per_tensor_out(
117+
ctx,
118+
input,
119+
_in_zero_point,
120+
out_zero_point,
121+
_out_multiplier,
122+
_out_shift,
123+
output);
124+
}
125+
100126
} // namespace native
101127
} // namespace HiFi
102128
} // namespace impl

0 commit comments

Comments
 (0)