Skip to content

Commit 4e1a308

Browse files
authored
[cadence] add reference quantized_fully_connected_out
Differential Revision: D70723811 Pull Request resolved: #9020
1 parent 337d73d commit 4e1a308

File tree

3 files changed

+108
-0
lines changed

3 files changed

+108
-0
lines changed

backends/cadence/aot/functions.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,13 @@
238238
kernels:
239239
- arg_meta: null
240240
kernel_name: impl::reference::quantized_conv_per_tensor_out
241+
242+
# Quantized fully-connected with tensor-valued weight_zero_point /
# out_multiplier / out_shift, bound to the reference kernel.
- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: impl::reference::quantized_fully_connected_out

# Variant taking scalar (per-tensor) quantization parameters instead of
# tensors; dispatches to the matching per-tensor reference kernel.
- func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: impl::reference::quantized_fully_connected_per_tensor_out

backends/cadence/reference/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ add_library(
8787
"quantized_relu_out.cpp"
8888
"quantized_layer_norm.cpp"
8989
"quantize_per_tensor.cpp"
90+
"quantized_fully_connected_out.cpp"
9091
"dequantize_per_tensor.cpp"
9192
"quantized_matmul_out.cpp"
9293
"im2row_out.cpp"
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#include <executorch/backends/cadence/reference/kernels/kernels.h>
9+
#include <executorch/backends/cadence/reference/operators/operators.h>
10+
#include <executorch/backends/cadence/reference/operators/quantized_ops.h>
11+
12+
namespace impl {
13+
namespace reference {
14+
namespace native {
15+
16+
using ::executorch::aten::optional;
17+
using ::executorch::aten::ScalarType;
18+
using ::executorch::aten::Tensor;
19+
using ::executorch::runtime::KernelRuntimeContext;
20+
21+
// Reference quantized fully-connected kernel where the quantization
// parameters (weight zero point, output multiplier, output shift) arrive as
// tensors. The actual math lives in the templated quantized_linear_<T>
// helper (quantized_ops.h); this wrapper only selects T from the output
// tensor's dtype.
//
// ctx / offset are unused here (marked __ET_UNUSED); `out` is written
// in place.
void quantized_fully_connected_out(
    __ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& weight,
    const Tensor& bias,
    int64_t in_zero_point,
    const Tensor& weight_zero_point_t,
    const Tensor& out_multiplier,
    const Tensor& out_shift,
    int64_t out_zero_point,
    __ET_UNUSED const optional<Tensor>& offset,
    Tensor& out) {
  // Case body stamped out once per supported quantized dtype by
  // ET_FORALL_CADENCE_QUANTIZED_TYPES below; undefined again at the end of
  // the function so the name does not leak.
#define QFC_DISPATCH_CASE(ctype, dtype)  \
  case ScalarType::dtype: {              \
    quantized_linear_<ctype>(            \
        in,                              \
        weight,                          \
        bias,                            \
        in_zero_point,                   \
        weight_zero_point_t,             \
        out_multiplier,                  \
        out_shift,                       \
        out_zero_point,                  \
        out);                            \
    break;                               \
  }

  switch (const ScalarType out_dtype = out.scalar_type(); out_dtype) {
    ET_FORALL_CADENCE_QUANTIZED_TYPES(QFC_DISPATCH_CASE);
    default:
      // Debug-build check: any dtype outside the FORALL list is a bug in
      // the caller / lowering, not a runtime condition.
      ET_DCHECK_MSG(
          false, "Unhandled dtype %s", torch::executor::toString(out_dtype));
  }
#undef QFC_DISPATCH_CASE
}
57+
58+
// Per-tensor variant of the reference quantized fully-connected kernel:
// weight zero point, output multiplier and output shift are plain scalars
// rather than tensors. Dispatches on the output dtype to the templated
// quantized_linear_per_tensor_<T> helper (quantized_ops.h).
//
// ctx / offset are unused here (marked __ET_UNUSED); `out` is written
// in place.
void quantized_fully_connected_per_tensor_out(
    __ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& weight,
    const Tensor& bias,
    int64_t in_zero_point,
    int64_t weight_zero_point,
    int64_t out_multiplier,
    int64_t out_shift,
    int64_t out_zero_point,
    __ET_UNUSED const optional<Tensor>& offset,
    Tensor& out) {
  // Case body stamped out once per supported quantized dtype by
  // ET_FORALL_CADENCE_QUANTIZED_TYPES below; undefined again at the end of
  // the function so the name does not leak.
#define QFC_PER_TENSOR_DISPATCH_CASE(ctype, dtype) \
  case ScalarType::dtype: {                        \
    quantized_linear_per_tensor_<ctype>(           \
        in,                                        \
        weight,                                    \
        bias,                                      \
        in_zero_point,                             \
        weight_zero_point,                         \
        out_multiplier,                            \
        out_shift,                                 \
        out_zero_point,                            \
        out);                                      \
    break;                                         \
  }

  switch (const ScalarType out_dtype = out.scalar_type(); out_dtype) {
    ET_FORALL_CADENCE_QUANTIZED_TYPES(QFC_PER_TENSOR_DISPATCH_CASE);
    default:
      // Debug-build check: any dtype outside the FORALL list is a bug in
      // the caller / lowering, not a runtime condition.
      ET_DCHECK_MSG(
          false, "Unhandled dtype %s", torch::executor::toString(out_dtype));
  }
#undef QFC_PER_TENSOR_DISPATCH_CASE
}
94+
95+
}; // namespace native
96+
}; // namespace reference
97+
}; // namespace impl

0 commit comments

Comments
 (0)