@@ -48,24 +48,27 @@ void quantized_relu_(
48
48
void quantized_relu_per_tensor_out (
49
49
KernelRuntimeContext& ctx,
50
50
const Tensor& input,
51
- const Tensor& in_zero_point,
51
+ const int64_t in_zero_point,
52
52
const int64_t out_zero_point,
53
- const Tensor& out_multiplier,
54
- const Tensor& out_shift,
53
+ const int64_t out_multiplier,
54
+ const int64_t out_shift,
55
55
Tensor& output) {
56
+ const uint8_t _in_zero_point = static_cast <uint8_t >(in_zero_point);
57
+ const uint8_t _out_zero_point = static_cast <uint8_t >(out_zero_point);
58
+ const int32_t _out_multiplier = static_cast <int32_t >(out_multiplier);
59
+ const int32_t _out_shift = static_cast <int32_t >(out_shift);
56
60
if (input.scalar_type () == executorch::aten::ScalarType::Byte ) {
57
61
const uint8_t * p_in = input.const_data_ptr <uint8_t >();
58
62
uint8_t * p_out = output.mutable_data_ptr <uint8_t >();
59
- uint8_t q_zero_point = in_zero_point.const_data_ptr <uint8_t >()[0 ];
60
63
61
64
WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u (
62
65
p_out,
63
66
p_in,
64
- ( int )q_zero_point ,
65
- out_multiplier. const_data_ptr < int32_t >()[ 0 ] ,
66
- out_shift. const_data_ptr < int32_t >()[ 0 ] ,
67
- ( int )out_zero_point ,
68
- ( int )out_zero_point ,
67
+ _in_zero_point ,
68
+ _out_multiplier ,
69
+ _out_shift ,
70
+ _out_zero_point ,
71
+ _out_zero_point ,
69
72
255 ,
70
73
input.numel ());
71
74
@@ -74,16 +77,15 @@ void quantized_relu_per_tensor_out(
74
77
} else if (input.scalar_type () == executorch::aten::ScalarType::Char) {
75
78
const int8_t * p_in = input.const_data_ptr <int8_t >();
76
79
int8_t * p_out = output.mutable_data_ptr <int8_t >();
77
- int8_t q_zero_point = in_zero_point.const_data_ptr <int8_t >()[0 ];
78
80
79
81
WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s (
80
82
p_out,
81
83
p_in,
82
- ( int )q_zero_point ,
83
- out_multiplier. const_data_ptr < int32_t >()[ 0 ] ,
84
- out_shift. const_data_ptr < int32_t >()[ 0 ] ,
85
- ( int )out_zero_point ,
86
- ( int )out_zero_point ,
84
+ _in_zero_point ,
85
+ _out_multiplier ,
86
+ _out_shift ,
87
+ _out_zero_point ,
88
+ _out_zero_point ,
87
89
127 ,
88
90
input.numel ());
89
91
@@ -97,6 +99,30 @@ void quantized_relu_per_tensor_out(
97
99
}
98
100
}
99
101
102
+ void quantized_relu_per_tensor_out (
103
+ KernelRuntimeContext& ctx,
104
+ const Tensor& input,
105
+ const Tensor& in_zero_point,
106
+ const int64_t out_zero_point,
107
+ const Tensor& out_multiplier,
108
+ const Tensor& out_shift,
109
+ Tensor& output) {
110
+ const uint8_t * p_in = input.const_data_ptr <uint8_t >();
111
+ uint8_t * p_out = output.mutable_data_ptr <uint8_t >();
112
+ uint8_t _in_zero_point = in_zero_point.const_data_ptr <uint8_t >()[0 ];
113
+ int32_t _out_multiplier = out_multiplier.const_data_ptr <int32_t >()[0 ];
114
+ int32_t _out_shift = out_shift.const_data_ptr <int32_t >()[0 ];
115
+
116
+ quantized_relu_per_tensor_out (
117
+ ctx,
118
+ input,
119
+ _in_zero_point,
120
+ out_zero_point,
121
+ _out_multiplier,
122
+ _out_shift,
123
+ output);
124
+ }
125
+
100
126
} // namespace native
101
127
} // namespace HiFi
102
128
} // namespace impl
0 commit comments