@@ -55,21 +55,20 @@ Tensor& stack_out(
55
55
const size_t ninputs = tensors.size ();
56
56
57
57
const auto out_type = out.scalar_type ();
58
- ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, " stack.out" , CTYPE_OUT, [&] {
58
+ ET_SWITCH_REALHBBF16_TYPES ( out_type, ctx, " stack.out" , CTYPE_OUT, [&] {
59
59
CTYPE_OUT* out_ptr = out.mutable_data_ptr <CTYPE_OUT>();
60
60
for (size_t i = 0 ; i < outer; ++i) {
61
61
for (size_t j = 0 ; j < ninputs; ++j) {
62
62
const auto in_type = tensors[j].scalar_type ();
63
- ET_SWITCH_REAL_TYPES_AND (
64
- Bool, in_type, ctx, " stack.out" , CTYPE_IN, [&] {
65
- const CTYPE_IN* const in_ptr =
66
- tensors[j].const_data_ptr <CTYPE_IN>() + i * inner;
63
+ ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " stack.out" , CTYPE_IN, [&] {
64
+ const CTYPE_IN* const in_ptr =
65
+ tensors[j].const_data_ptr <CTYPE_IN>() + i * inner;
67
66
68
- for (size_t k = 0 ; k < inner; ++k) {
69
- out_ptr[k] = static_cast <CTYPE_OUT>(in_ptr[k]);
70
- }
71
- out_ptr += inner;
72
- });
67
+ for (size_t k = 0 ; k < inner; ++k) {
68
+ out_ptr[k] = static_cast <CTYPE_OUT>(in_ptr[k]);
69
+ }
70
+ out_ptr += inner;
71
+ });
73
72
}
74
73
}
75
74
});
0 commit comments