From 997aab5f410ea37371dc24b14123c76926d07fb4 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 23 Jan 2025 11:19:45 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_stack.cpp | 19 +++++++++---------- kernels/test/op_stack_test.cpp | 2 +- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/kernels/portable/cpu/op_stack.cpp b/kernels/portable/cpu/op_stack.cpp index d3cca7ea817..e026a47fe6b 100644 --- a/kernels/portable/cpu/op_stack.cpp +++ b/kernels/portable/cpu/op_stack.cpp @@ -55,21 +55,20 @@ Tensor& stack_out( const size_t ninputs = tensors.size(); const auto out_type = out.scalar_type(); - ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "stack.out", CTYPE_OUT, [&] { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "stack.out", CTYPE_OUT, [&] { CTYPE_OUT* out_ptr = out.mutable_data_ptr(); for (size_t i = 0; i < outer; ++i) { for (size_t j = 0; j < ninputs; ++j) { const auto in_type = tensors[j].scalar_type(); - ET_SWITCH_REAL_TYPES_AND( - Bool, in_type, ctx, "stack.out", CTYPE_IN, [&] { - const CTYPE_IN* const in_ptr = - tensors[j].const_data_ptr() + i * inner; + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "stack.out", CTYPE_IN, [&] { + const CTYPE_IN* const in_ptr = + tensors[j].const_data_ptr() + i * inner; - for (size_t k = 0; k < inner; ++k) { - out_ptr[k] = static_cast(in_ptr[k]); - } - out_ptr += inner; - }); + for (size_t k = 0; k < inner; ++k) { + out_ptr[k] = static_cast(in_ptr[k]); + } + out_ptr += inner; + }); } } }); diff --git a/kernels/test/op_stack_test.cpp b/kernels/test/op_stack_test.cpp index e1a666306fa..b621ee67454 100644 --- a/kernels/test/op_stack_test.cpp +++ b/kernels/test/op_stack_test.cpp @@ -276,7 +276,7 @@ TEST_F(OpStackOutTest, InsertEnd) { /// zeros(). TEST_F(OpStackOutTest, AllDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY // TODO: Also add tests for half, complex, quantized, and other types. Easiest // way to do that would be to make TensorFactory support zeros() and ones()