From 27373332efb9a283efadc9111e8d7212d6e10dff Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 22 Jan 2025 15:27:14 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_prod.cpp | 8 ++--- kernels/test/op_prod_test.cpp | 60 +++++++++++++++++++++----------- 2 files changed, 44 insertions(+), 24 deletions(-) diff --git a/kernels/portable/cpu/op_prod.cpp b/kernels/portable/cpu/op_prod.cpp index 9580dee2d12..a1b9f720349 100644 --- a/kernels/portable/cpu/op_prod.cpp +++ b/kernels/portable/cpu/op_prod.cpp @@ -33,8 +33,8 @@ Tensor& prod_out( ScalarType out_type = out.scalar_type(); constexpr auto name = "prod.int_out"; - ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] { - ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { const auto data_in = in.const_data_ptr(); auto data_out = out.mutable_data_ptr(); data_out[0] = static_cast(1); @@ -73,8 +73,8 @@ Tensor& prod_int_out( ScalarType out_type = out.scalar_type(); constexpr auto name = "prod.int_out"; - ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] { - ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { CTYPE_OUT* out_data = out.mutable_data_ptr(); for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { CTYPE_OUT prod = 1; diff --git a/kernels/test/op_prod_test.cpp b/kernels/test/op_prod_test.cpp index f96eea9564c..a774bc564c6 100644 --- a/kernels/test/op_prod_test.cpp +++ b/kernels/test/op_prod_test.cpp @@ -45,6 +45,24 @@ class OpProdOutTest : public ::testing::Test { // first. torch::executor::runtime_init(); } + + template + void test_dtype() { + TensorFactory tf; + TensorFactory< + executorch::runtime::isIntegralType(DTYPE, /*includeBool*/ true) + ? ScalarType::Long + : DTYPE> + tf_out; + + Tensor self = tf.make({2, 3}, {1, 2, 3, 4, 5, 6}); + optional dtype{}; + Tensor out = tf_out.zeros({}); + Tensor out_expected = + tf_out.make({}, {DTYPE == ScalarType::Bool ? 1 : 720}); + op_prod_out(self, dtype, out); + EXPECT_TENSOR_CLOSE(out, out_expected); + } }; class OpProdIntOutTest : public ::testing::Test { @@ -54,30 +72,32 @@ class OpProdIntOutTest : public ::testing::Test { // first. torch::executor::runtime_init(); } -}; -TEST_F(OpProdOutTest, SmokeTest) { - TensorFactory tfFloat; + template + void test_dtype() { + TensorFactory tf; - Tensor self = tfFloat.make({2, 3}, {1, 2, 3, 4, 5, 6}); - optional dtype{}; - Tensor out = tfFloat.zeros({}); - Tensor out_expected = tfFloat.make({}, {720}); - op_prod_out(self, dtype, out); - EXPECT_TENSOR_CLOSE(out, out_expected); -} + Tensor self = tf.make({2, 3}, {1, 2, 3, 4, 5, 6}); + int64_t dim = 0; + bool keepdim = false; + optional dtype{}; + Tensor out = tf.zeros({3}); + Tensor out_expected = tf.make({3}, {4, 10, 18}); + op_prod_int_out(self, dim, keepdim, dtype, out); + EXPECT_TENSOR_CLOSE(out, out_expected); + } +}; -TEST_F(OpProdIntOutTest, SmokeTest) { - TensorFactory tfFloat; +TEST_F(OpProdOutTest, SmokeTest){ +#define TEST_ENTRY(ctype, dtype) test_dtype(); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY) +#undef TEST_ENTRY +} - Tensor self = tfFloat.make({2, 3}, {1, 2, 3, 4, 5, 6}); - int64_t dim = 0; - bool keepdim = false; - optional dtype{}; - Tensor out = tfFloat.zeros({3}); - Tensor out_expected = tfFloat.make({3}, {4, 10, 18}); - op_prod_int_out(self, dim, keepdim, dtype, out); - EXPECT_TENSOR_CLOSE(out, out_expected); +TEST_F(OpProdIntOutTest, SmokeTest){ +#define TEST_ENTRY(ctype, dtype) test_dtype(); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY) +#undef TEST_ENTRY } TEST_F(OpProdIntOutTest, SmokeTestKeepdim) {