@@ -45,6 +45,60 @@ void quantized_relu_(
45
45
}
46
46
}
47
47
48
+ void quantized_relu_per_tensor_out (
49
+ KernelRuntimeContext& ctx,
50
+ const Tensor& input,
51
+ const int64_t in_zero_point,
52
+ const int64_t out_zero_point,
53
+ const int64_t out_multiplier,
54
+ const int64_t out_shift,
55
+ Tensor& output) {
56
+ const uint8_t _in_zero_point = static_cast <uint8_t >(in_zero_point);
57
+ const uint8_t _out_zero_point = static_cast <uint8_t >(out_zero_point);
58
+ const int32_t _out_multiplier = static_cast <int32_t >(out_multiplier);
59
+ const int32_t _out_shift = static_cast <int32_t >(out_shift);
60
+ if (input.scalar_type () == executorch::aten::ScalarType::Byte ) {
61
+ const uint8_t * p_in = input.const_data_ptr <uint8_t >();
62
+ uint8_t * p_out = output.mutable_data_ptr <uint8_t >();
63
+
64
+ WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u (
65
+ p_out,
66
+ p_in,
67
+ _in_zero_point,
68
+ _out_multiplier,
69
+ _out_shift,
70
+ _out_zero_point,
71
+ _out_zero_point,
72
+ 255 ,
73
+ input.numel ());
74
+
75
+ ET_CHECK_MSG (ret_val == 0 , " An internal error occured" );
76
+
77
+ } else if (input.scalar_type () == executorch::aten::ScalarType::Char) {
78
+ const int8_t * p_in = input.const_data_ptr <int8_t >();
79
+ int8_t * p_out = output.mutable_data_ptr <int8_t >();
80
+
81
+ WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s (
82
+ p_out,
83
+ p_in,
84
+ _in_zero_point,
85
+ _out_multiplier,
86
+ _out_shift,
87
+ _out_zero_point,
88
+ _out_zero_point,
89
+ 127 ,
90
+ input.numel ());
91
+
92
+ ET_CHECK_MSG (ret_val == 0 , " An internal error occured" );
93
+
94
+ } else {
95
+ ET_CHECK_MSG (
96
+ false ,
97
+ " Unhandled input dtype %hhd" ,
98
+ static_cast <int8_t >(input.scalar_type ()));
99
+ }
100
+ }
101
+
48
102
void quantized_relu_per_tensor_out (
49
103
KernelRuntimeContext& ctx,
50
104
const Tensor& input,
0 commit comments