diff --git a/kernels/portable/cpu/op_scatter.cpp b/kernels/portable/cpu/op_scatter.cpp index ee9b202c6c9..31e24d95823 100644 --- a/kernels/portable/cpu/op_scatter.cpp +++ b/kernels/portable/cpu/op_scatter.cpp @@ -125,7 +125,7 @@ Tensor& scatter_src_out( constexpr auto name = "scatter.src_out"; - ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { scatter_src_helper(in, dim, index, src, out); }); @@ -158,7 +158,7 @@ Tensor& scatter_value_out( CTYPE_VAL val; utils::extract_scalar(value, &val); - ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { scatter_value_helper(in, dim, index, val, out); }); }); diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index cc50fed66de..65ec529ecdf 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -201,6 +201,7 @@ set(all_test_sources "op_rsqrt_test.cpp" "op_rsub_test.cpp" "op_scalar_tensor_test.cpp" + "op_scatter_test.cpp" "op_scatter_add_test.cpp" "op_select_scatter_test.cpp" "op_select_copy_test.cpp" diff --git a/kernels/test/op_scatter_test.cpp b/kernels/test/op_scatter_test.cpp index 83c112a8c34..7c7bc862be6 100644 --- a/kernels/test/op_scatter_test.cpp +++ b/kernels/test/op_scatter_test.cpp @@ -368,7 +368,7 @@ class OpScatterValueOutTest : public OperatorTest { TEST_F(OpScatterSrcOutTest, AllValidInputOutputSupport) { #define TEST_ENTRY(CTYPE, DTYPE) test_scatter_src_out(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } @@ -381,7 +381,7 @@ TEST_F(OpScatterSrcOutTest, InvalidDimensionsDies) { TEST_F(OpScatterValueOutTest, AllValidInputOutputSupport) { #define TEST_ENTRY(CTYPE, DTYPE) test_scatter_value_out(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY }