Skip to content

Commit 6a0c088

Browse files
zonglinpeng authored and facebook-github-bot committed
quantized_relu_per_tensor_out
Summary: fix quantized_relu_per_tensor. Differential Revision: D69015308
1 parent 38e0bc7 commit 6a0c088

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

backends/cadence/hifi/operators/op_quantized_relu_out.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,60 @@ void quantized_relu_(
4545
}
4646
}
4747

48+
void quantized_relu_per_tensor_out(
49+
KernelRuntimeContext& ctx,
50+
const Tensor& input,
51+
const int64_t in_zero_point,
52+
const int64_t out_zero_point,
53+
const int64_t out_multiplier,
54+
const int64_t out_shift,
55+
Tensor& output) {
56+
const uint8_t _in_zero_point = static_cast<uint8_t>(in_zero_point);
57+
const uint8_t _out_zero_point = static_cast<uint8_t>(out_zero_point);
58+
const int32_t _out_multiplier = static_cast<int32_t>(out_multiplier);
59+
const int32_t _out_shift = static_cast<int32_t>(out_shift);
60+
if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
61+
const uint8_t* p_in = input.const_data_ptr<uint8_t>();
62+
uint8_t* p_out = output.mutable_data_ptr<uint8_t>();
63+
64+
WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u(
65+
p_out,
66+
p_in,
67+
_in_zero_point,
68+
_out_multiplier,
69+
_out_shift,
70+
_out_zero_point,
71+
_out_zero_point,
72+
255,
73+
input.numel());
74+
75+
ET_CHECK_MSG(ret_val == 0, "An internal error occured");
76+
77+
} else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
78+
const int8_t* p_in = input.const_data_ptr<int8_t>();
79+
int8_t* p_out = output.mutable_data_ptr<int8_t>();
80+
81+
WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s(
82+
p_out,
83+
p_in,
84+
_in_zero_point,
85+
_out_multiplier,
86+
_out_shift,
87+
_out_zero_point,
88+
_out_zero_point,
89+
127,
90+
input.numel());
91+
92+
ET_CHECK_MSG(ret_val == 0, "An internal error occured");
93+
94+
} else {
95+
ET_CHECK_MSG(
96+
false,
97+
"Unhandled input dtype %hhd",
98+
static_cast<int8_t>(input.scalar_type()));
99+
}
100+
}
101+
48102
void quantized_relu_per_tensor_out(
49103
KernelRuntimeContext& ctx,
50104
const Tensor& input,

0 commit comments

Comments
 (0)