Skip to content

Commit d577310

Browse files
committed
Support Half/BFloat16 in unbind_copy
Partial fix for #7748. ghstack-source-id: edc2d44 ghstack-comment-id: 2611071343 Pull Request resolved: #7908
1 parent 5ee5f2f commit d577310

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

kernels/portable/cpu/op_unbind_copy.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ void unbind_copy_int_out(
5454
ScalarType in_type = input.scalar_type();
5555
ScalarType out_type = out[0].scalar_type();
5656

57-
ET_SWITCH_REAL_TYPES_AND(
58-
Bool, in_type, ctx, "unbind_copy.int_out", CTYPE_IN, [&]() {
59-
ET_SWITCH_REAL_TYPES_AND(
60-
Bool, out_type, ctx, "unbind_copy.int_out", CTYPE_OUT, [&]() {
57+
ET_SWITCH_REALHBF16_TYPES(
58+
in_type, ctx, "unbind_copy.int_out", CTYPE_IN, [&]() {
59+
ET_SWITCH_REALHBF16_TYPES(
60+
out_type, ctx, "unbind_copy.int_out", CTYPE_OUT, [&]() {
6161
const CTYPE_IN* const input_data =
6262
input.const_data_ptr<CTYPE_IN>();
6363
for (size_t i = 0, e = out.size(); i < e; ++i) {

kernels/test/op_unbind_copy_test.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -208,19 +208,19 @@ class OpUnbindCopyIntOutTest : public OperatorTest {
208208
*/
209209
TEST_F(OpUnbindCopyIntOutTest, Unbind1x2x3OnDim0AllRealDtypes) {
210210
#define TEST_ENTRY(ctype, dtype) test_unbind_dim0<ScalarType::dtype>();
211-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
211+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
212212
#undef TEST_ENTRY
213213
}
214214

215215
TEST_F(OpUnbindCopyIntOutTest, Unbind1x2x3OnDim1AllRealDTypes) {
216216
#define TEST_ENTRY(ctype, dtype) test_unbind_dim1<ScalarType::dtype>();
217-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
217+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
218218
#undef TEST_ENTRY
219219
}
220220

221221
TEST_F(OpUnbindCopyIntOutTest, Unbind1x2x3OnDim2AllRealDTypes) {
222222
#define TEST_ENTRY(ctype, dtype) test_unbind_dim2<ScalarType::dtype>();
223-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
223+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
224224
#undef TEST_ENTRY
225225
}
226226

0 commit comments

Comments
 (0)