Skip to content

Commit 3ccb792

Browse files
committed
Support Half/BFloat16 in stack
Partial fix for #7748.

ghstack-source-id: 4a0c4e2
ghstack-comment-id: 2610826380
Pull Request resolved: #7894
1 parent 85fbdc8 commit 3ccb792

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

kernels/portable/cpu/op_stack.cpp

+9-10
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,20 @@ Tensor& stack_out(
5555
const size_t ninputs = tensors.size();
5656

5757
const auto out_type = out.scalar_type();
58-
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "stack.out", CTYPE_OUT, [&] {
58+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "stack.out", CTYPE_OUT, [&] {
5959
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
6060
for (size_t i = 0; i < outer; ++i) {
6161
for (size_t j = 0; j < ninputs; ++j) {
6262
const auto in_type = tensors[j].scalar_type();
63-
ET_SWITCH_REAL_TYPES_AND(
64-
Bool, in_type, ctx, "stack.out", CTYPE_IN, [&] {
65-
const CTYPE_IN* const in_ptr =
66-
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
63+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "stack.out", CTYPE_IN, [&] {
64+
const CTYPE_IN* const in_ptr =
65+
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
6766

68-
for (size_t k = 0; k < inner; ++k) {
69-
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
70-
}
71-
out_ptr += inner;
72-
});
67+
for (size_t k = 0; k < inner; ++k) {
68+
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
69+
}
70+
out_ptr += inner;
71+
});
7372
}
7473
}
7574
});

kernels/test/op_stack_test.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ TEST_F(OpStackOutTest, InsertEnd) {
276276
/// zeros().
277277
TEST_F(OpStackOutTest, AllDtypesSupported) {
278278
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
279-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
279+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
280280
#undef TEST_ENTRY
281281
// TODO: Also add tests for half, complex, quantized, and other types. Easiest
282282
// way to do that would be to make TensorFactory support zeros() and ones()

0 commit comments

Comments (0)