diff --git a/kernels/portable/cpu/op_sum.cpp b/kernels/portable/cpu/op_sum.cpp index 3c897c3487e..f5888018ae1 100644 --- a/kernels/portable/cpu/op_sum.cpp +++ b/kernels/portable/cpu/op_sum.cpp @@ -43,10 +43,10 @@ Tensor& sum_dim_out( ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND( - Bool, in.scalar_type(), ctx, "sum.IntList_out", CTYPE_IN, [&] { - ET_SWITCH_REAL_TYPES_AND( - Bool, out.scalar_type(), ctx, "sum.IntList_out", CTYPE_OUT, [&] { + ET_SWITCH_REALHBBF16_TYPES( + in.scalar_type(), ctx, "sum.IntList_out", CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES( + out.scalar_type(), ctx, "sum.IntList_out", CTYPE_OUT, [&] { CTYPE_OUT* out_data = out.mutable_data_ptr(); for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { CTYPE_OUT sum = 0; diff --git a/kernels/test/op_sum_test.cpp b/kernels/test/op_sum_test.cpp index e8af989e94b..9f1700a901d 100644 --- a/kernels/test/op_sum_test.cpp +++ b/kernels/test/op_sum_test.cpp @@ -302,9 +302,9 @@ TEST_F(OpSumOutTest, AllRealInputRealOutputPasses) { test_sum_dim_out_dtype(); #define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ - ET_FORALL_REAL_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); + ET_FORALL_REALHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY #undef TEST_KERNEL }