Skip to content

Commit 14df69f

Browse files
committed
Support Half/BFloat16 in scatter
Partial fix for #7748. ghstack-source-id: 40b6cdd ghstack-comment-id: 2608589452 Pull Request resolved: #7864
1 parent 26665f3 commit 14df69f

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

kernels/portable/cpu/op_scatter.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ Tensor& scatter_src_out(
125125

126126
constexpr auto name = "scatter.src_out";
127127

128-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
128+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
129129
scatter_src_helper<CTYPE>(in, dim, index, src, out);
130130
});
131131

@@ -158,7 +158,7 @@ Tensor& scatter_value_out(
158158
CTYPE_VAL val;
159159
utils::extract_scalar(value, &val);
160160

161-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
161+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
162162
scatter_value_helper<CTYPE>(in, dim, index, val, out);
163163
});
164164
});

kernels/test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ set(all_test_sources
201201
"op_rsqrt_test.cpp"
202202
"op_rsub_test.cpp"
203203
"op_scalar_tensor_test.cpp"
204+
"op_scatter_test.cpp"
204205
"op_scatter_add_test.cpp"
205206
"op_select_scatter_test.cpp"
206207
"op_select_copy_test.cpp"

kernels/test/op_scatter_test.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ class OpScatterValueOutTest : public OperatorTest {
368368

369369
TEST_F(OpScatterSrcOutTest, AllValidInputOutputSupport) {
370370
#define TEST_ENTRY(CTYPE, DTYPE) test_scatter_src_out<ScalarType::DTYPE>();
371-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
371+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
372372
#undef TEST_ENTRY
373373
}
374374

@@ -381,7 +381,7 @@ TEST_F(OpScatterSrcOutTest, InvalidDimensionsDies) {
381381

382382
TEST_F(OpScatterValueOutTest, AllValidInputOutputSupport) {
383383
#define TEST_ENTRY(CTYPE, DTYPE) test_scatter_value_out<ScalarType::DTYPE>();
384-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
384+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
385385
#undef TEST_ENTRY
386386
}
387387

0 commit comments

Comments
 (0)