From ae50298d01ae20a2736ef0c332335946ebb5acf7 Mon Sep 17 00:00:00 2001 From: Zonglin Peng Date: Tue, 4 Feb 2025 14:30:04 -0800 Subject: [PATCH] Fixed quantized_relu_per_tensor_out (#8143) Summary: fix quantized_relu_per_tensor Differential Revision: D69015308 --- .../hifi/operators/op_quantized_relu_out.cpp | 56 ++++++++++++++----- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/backends/cadence/hifi/operators/op_quantized_relu_out.cpp b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp index 0860109f7c1..06eb8aa3a5b 100644 --- a/backends/cadence/hifi/operators/op_quantized_relu_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp @@ -48,24 +48,27 @@ void quantized_relu_( void quantized_relu_per_tensor_out( KernelRuntimeContext& ctx, const Tensor& input, - const Tensor& in_zero_point, + const int64_t in_zero_point, const int64_t out_zero_point, - const Tensor& out_multiplier, - const Tensor& out_shift, + const int64_t out_multiplier, + const int64_t out_shift, Tensor& output) { + const uint8_t _in_zero_point = static_cast<uint8_t>(in_zero_point); + const uint8_t _out_zero_point = static_cast<uint8_t>(out_zero_point); + const int32_t _out_multiplier = static_cast<int32_t>(out_multiplier); + const int32_t _out_shift = static_cast<int32_t>(out_shift); if (input.scalar_type() == executorch::aten::ScalarType::Byte) { const uint8_t* p_in = input.const_data_ptr<uint8_t>(); uint8_t* p_out = output.mutable_data_ptr<uint8_t>(); - uint8_t q_zero_point = in_zero_point.const_data_ptr<uint8_t>()[0]; WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u( p_out, p_in, - (int)q_zero_point, - out_multiplier.const_data_ptr<int32_t>()[0], - out_shift.const_data_ptr<int32_t>()[0], - (int)out_zero_point, - (int)out_zero_point, + _in_zero_point, + _out_multiplier, + _out_shift, + _out_zero_point, + _out_zero_point, 255, input.numel()); @@ -74,16 +77,15 @@ void quantized_relu_per_tensor_out( } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { const int8_t* p_in = input.const_data_ptr<int8_t>(); int8_t* p_out = 
output.mutable_data_ptr<int8_t>(); - int8_t q_zero_point = in_zero_point.const_data_ptr<int8_t>()[0]; WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s( p_out, p_in, - (int)q_zero_point, - out_multiplier.const_data_ptr<int32_t>()[0], - out_shift.const_data_ptr<int32_t>()[0], - (int)out_zero_point, - (int)out_zero_point, + _in_zero_point, + _out_multiplier, + _out_shift, + _out_zero_point, + _out_zero_point, 127, input.numel()); @@ -97,6 +99,30 @@ void quantized_relu_per_tensor_out( } } +void quantized_relu_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + const uint8_t* p_in = input.const_data_ptr<uint8_t>(); + uint8_t* p_out = output.mutable_data_ptr<uint8_t>(); + uint8_t _in_zero_point = in_zero_point.const_data_ptr<uint8_t>()[0]; + int32_t _out_multiplier = out_multiplier.const_data_ptr<int32_t>()[0]; + int32_t _out_shift = out_shift.const_data_ptr<int32_t>()[0]; + + quantized_relu_per_tensor_out( + ctx, + input, + _in_zero_point, + out_zero_point, + _out_multiplier, + _out_shift, + output); + } // namespace native } // namespace HiFi } // namespace impl