Skip to content

Commit 37109d4

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in select_scatter (#7865)
Partial fix for #7748.
1 parent bc2de23 commit 37109d4

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

kernels/portable/cpu/op_select_scatter.cpp

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,20 @@ Tensor& select_scatter_out(
7373
ScalarType in_type = in.scalar_type();
7474
ScalarType src_type = src.scalar_type();
7575

76-
ET_SWITCH_REAL_TYPES_AND(
77-
Bool, in_type, ctx, "select_scatter.out", CTYPE, [&]() {
78-
ET_SWITCH_REAL_TYPES_AND(
79-
Bool, src_type, ctx, "select_scatter.out", CTYPE_SRC, [&]() {
80-
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
81-
const CTYPE_SRC* const src_data = src.const_data_ptr<CTYPE_SRC>();
82-
83-
for (size_t i = 0; i < leading_dims; ++i) {
84-
for (size_t j = 0; j < trailing_stride; ++j) {
85-
out_data[start_offset + i * out_step + j] =
86-
convert<CTYPE, CTYPE_SRC>(
87-
src_data[i * trailing_stride + j]);
88-
}
89-
}
90-
});
91-
});
76+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "select_scatter.out", CTYPE, [&]() {
77+
ET_SWITCH_REALHBBF16_TYPES(
78+
src_type, ctx, "select_scatter.out", CTYPE_SRC, [&]() {
79+
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
80+
const CTYPE_SRC* const src_data = src.const_data_ptr<CTYPE_SRC>();
81+
82+
for (size_t i = 0; i < leading_dims; ++i) {
83+
for (size_t j = 0; j < trailing_stride; ++j) {
84+
out_data[start_offset + i * out_step + j] =
85+
convert<CTYPE, CTYPE_SRC>(src_data[i * trailing_stride + j]);
86+
}
87+
}
88+
});
89+
});
9290

9391
return out;
9492
}

kernels/test/op_select_scatter_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,9 @@ TEST_F(OpSelectScatterOutTest, OutputDynamicShape) {
501501
/// zeros().
502502
TEST_F(OpSelectScatterOutTest, AllDtypesSupported) {
503503
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
504-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
504+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
505505
#undef TEST_ENTRY
506-
// TODO: Also add tests for half, complex, quantized, and other types. Easiest
506+
// TODO: Also add tests for complex, quantized, and other types. Easiest
507507
// way to do that would be to make TensorFactory support zeros() and ones()
508508
// for those types.
509509
}

0 commit comments

Comments
 (0)