|
6 | 6 | * LICENSE file in the root directory of this source tree. |
7 | 7 | */ |
8 | 8 |
|
9 | | -#include <executorch/runtime/kernel/kernel_includes.h> |
10 | | - |
11 | 9 | #include <executorch/kernels/optimized/blas/CPUBlas.h> |
| 10 | +#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h> |
| 11 | +#include <executorch/runtime/kernel/kernel_includes.h> |
12 | 12 |
|
13 | 13 | // Performs a batch matrix-matrix product of matrices stored in input and mat2. |
14 | 14 |
|
@@ -136,33 +136,32 @@ Error resize_out_tensor(const Tensor& self, const Tensor& mat2, Tensor& out) { |
136 | 136 |
|
137 | 137 | // bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) |
138 | 138 | Tensor& opt_bmm_out( |
139 | | - KernelRuntimeContext& context, |
| 139 | + KernelRuntimeContext& ctx, |
140 | 140 | const Tensor& self, |
141 | 141 | const Tensor& mat2, |
142 | 142 | Tensor& out) { |
143 | | - (void)context; |
| 143 | + (void)ctx; |
144 | 144 |
|
145 | 145 | ET_KERNEL_CHECK( |
146 | | - context, |
| 146 | + ctx, |
147 | 147 | resize_out_tensor(self, mat2, out) == Error::Ok, |
148 | 148 | InvalidArgument, |
149 | 149 | out); |
150 | 150 | ET_KERNEL_CHECK( |
151 | | - context, check_bmm_out_args(self, mat2, out), InvalidArgument, out); |
152 | | - |
153 | | -#define BMM_TENSOR(ctype, dtype) \ |
154 | | - case ScalarType::dtype: \ |
155 | | - bmm_kernel<ctype>(self, mat2, out); \ |
156 | | - break; |
157 | | - |
158 | | - auto scalar_type = self.scalar_type(); |
159 | | - switch (scalar_type) { |
160 | | - ET_FORALL_REAL_TYPES_AND(Half, BMM_TENSOR) |
161 | | - default: |
162 | | - ET_CHECK_MSG( |
163 | | - false, "Unhandled dtype %" PRId8, static_cast<int8_t>(scalar_type)); |
| 151 | + ctx, check_bmm_out_args(self, mat2, out), InvalidArgument, out); |
| 152 | + |
| 153 | + constexpr auto name = "bmm.out"; |
| 154 | + auto self_type = self.scalar_type(); |
| 155 | + |
| 156 | + if (executorch::runtime::isComplexType(self_type)) { |
| 157 | + ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() { |
| 158 | + internal::bmm_out_impl<CTYPE>(self, mat2, out); |
| 159 | + }); |
| 160 | + } else { |
| 161 | + ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() { |
| 162 | + bmm_kernel<CTYPE>(self, mat2, out); |
| 163 | + }); |
164 | 164 | } |
165 | | -#undef BMM_TENSOR |
166 | 165 |
|
167 | 166 | return out; |
168 | 167 | } |
|
0 commit comments