Skip to content

Commit e290799

Browse files
committed
Support Half/BFloat16 in sum
Partial fix for #7748. ghstack-source-id: a87d99f ghstack-comment-id: 2610839064 Pull Request resolved: #7897
1 parent bebceb7 commit e290799

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

kernels/portable/cpu/op_sum.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ Tensor& sum_dim_out(
4343

4444
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
4545

46-
ET_SWITCH_REAL_TYPES_AND(
47-
Bool, in.scalar_type(), ctx, "sum.IntList_out", CTYPE_IN, [&] {
48-
ET_SWITCH_REAL_TYPES_AND(
49-
Bool, out.scalar_type(), ctx, "sum.IntList_out", CTYPE_OUT, [&] {
46+
ET_SWITCH_REALHBBF16_TYPES(
47+
in.scalar_type(), ctx, "sum.IntList_out", CTYPE_IN, [&] {
48+
ET_SWITCH_REALHBBF16_TYPES(
49+
out.scalar_type(), ctx, "sum.IntList_out", CTYPE_OUT, [&] {
5050
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
5151
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
5252
CTYPE_OUT sum = 0;

kernels/test/op_sum_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,9 @@ TEST_F(OpSumOutTest, AllRealInputRealOutputPasses) {
302302
test_sum_dim_out_dtype<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
303303

304304
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
305-
ET_FORALL_REAL_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
305+
ET_FORALL_REALHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
306306

307-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
307+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
308308
#undef TEST_ENTRY
309309
#undef TEST_KERNEL
310310
}

0 commit comments

Comments
 (0)