From b0bea09fe846ec070d88368c3658538ff4a4ff7f Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 22 Jan 2025 16:46:23 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_select_scatter.cpp | 30 ++++++++++------------ kernels/test/op_select_scatter_test.cpp | 4 +-- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/kernels/portable/cpu/op_select_scatter.cpp b/kernels/portable/cpu/op_select_scatter.cpp index 41e034aae02..e25e311eef3 100644 --- a/kernels/portable/cpu/op_select_scatter.cpp +++ b/kernels/portable/cpu/op_select_scatter.cpp @@ -73,22 +73,20 @@ Tensor& select_scatter_out( ScalarType in_type = in.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REAL_TYPES_AND( - Bool, in_type, ctx, "select_scatter.out", CTYPE, [&]() { - ET_SWITCH_REAL_TYPES_AND( - Bool, src_type, ctx, "select_scatter.out", CTYPE_SRC, [&]() { - CTYPE* const out_data = out.mutable_data_ptr(); - const CTYPE_SRC* const src_data = src.const_data_ptr(); - - for (size_t i = 0; i < leading_dims; ++i) { - for (size_t j = 0; j < trailing_stride; ++j) { - out_data[start_offset + i * out_step + j] = - convert( - src_data[i * trailing_stride + j]); - } - } - }); - }); + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "select_scatter.out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES( + src_type, ctx, "select_scatter.out", CTYPE_SRC, [&]() { + CTYPE* const out_data = out.mutable_data_ptr(); + const CTYPE_SRC* const src_data = src.const_data_ptr(); + + for (size_t i = 0; i < leading_dims; ++i) { + for (size_t j = 0; j < trailing_stride; ++j) { + out_data[start_offset + i * out_step + j] = + convert(src_data[i * trailing_stride + j]); + } + } + }); + }); return out; } diff --git a/kernels/test/op_select_scatter_test.cpp b/kernels/test/op_select_scatter_test.cpp index 038d00afbf0..ab465a8a363 100644 --- a/kernels/test/op_select_scatter_test.cpp +++ b/kernels/test/op_select_scatter_test.cpp @@ -501,9 +501,9 @@ TEST_F(OpSelectScatterOutTest, OutputDynamicShape) { /// zeros(). TEST_F(OpSelectScatterOutTest, AllDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY - // TODO: Also add tests for half, complex, quantized, and other types. Easiest + // TODO: Also add tests for complex, quantized, and other types. Easiest // way to do that would be to make TensorFactory support zeros() and ones() // for those types. }