From f1ace777b56cba104026b59dad0d1616e8caab0f Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 28 Feb 2025 17:16:10 -0800 Subject: [PATCH 1/8] Update [ghstack-poisoned] --- kernels/portable/cpu/op_argmax.cpp | 5 ++++- kernels/portable/cpu/op_argmin.cpp | 12 +++++++++++- kernels/test/op_argmax_test.cpp | 13 +++++++++++++ kernels/test/op_argmin_test.cpp | 13 +++++++++++++ 4 files changed, 41 insertions(+), 2 deletions(-) diff --git a/kernels/portable/cpu/op_argmax.cpp b/kernels/portable/cpu/op_argmax.cpp index 5eb656d5b76..7e95c305e89 100644 --- a/kernels/portable/cpu/op_argmax.cpp +++ b/kernels/portable/cpu/op_argmax.cpp @@ -49,7 +49,10 @@ Tensor& argmax_out( for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { std::tuple acc = reduce_over_dim( [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) { - if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) { + // the below condition as written is equivalent to + // !isnan(accval) && (isnan(v) || v > acc_val). See + // argument in op_argmin.cpp. + if (!std::isnan(acc_val) && !(v <= acc_val)) { acc_val = v; acc_ix = ix; } diff --git a/kernels/portable/cpu/op_argmin.cpp b/kernels/portable/cpu/op_argmin.cpp index 1c4a2572ea8..8a94994c4be 100644 --- a/kernels/portable/cpu/op_argmin.cpp +++ b/kernels/portable/cpu/op_argmin.cpp @@ -49,7 +49,17 @@ Tensor& argmin_out( for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { std::tuple acc = reduce_over_dim( [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) { - if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) { + // the below condition as written is equivalent to !isnan(accval) && + // (isnan(v) || v < acc_val). cases: + // - if neither acc_val nor v is NaN, !(v >= acc_val) is + // trivially equivalent to v < acc_val. + // - if acc_val is NaN, the whole thing is trivially false. + // - if acc_val is not NaN and v is NaN, then v >= acc_val + // - is false because all comparisons involving NaN are + // - false, so the result is true. The result is trivially + // - true for the above condition that uses isnan(v) as + // - well. + if (!std::isnan(acc_val) && !(v >= acc_val)) { acc_val = v; acc_ix = ix; } diff --git a/kernels/test/op_argmax_test.cpp b/kernels/test/op_argmax_test.cpp index 66c79cefff7..4d68dfe88be 100644 --- a/kernels/test/op_argmax_test.cpp +++ b/kernels/test/op_argmax_test.cpp @@ -90,3 +90,16 @@ TEST_F(OpArgmaxTest, SanityCheckNullDim) { EXPECT_TENSOR_EQ(out, expected); // clang-format on } + +TEST_F(OpArgmaxTest, FirstNaNWins) { + TensorFactory tf_float; + Tensor in = tf_float.make({4}, {1, NAN, -4, NAN}); + + TensorFactory tf_long; + Tensor out = tf_long.zeros({}); + Tensor expected = tf_long.make({}, {1}); + + Tensor ret = op_argmax_out(in, {}, false, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); +} diff --git a/kernels/test/op_argmin_test.cpp b/kernels/test/op_argmin_test.cpp index 250fe4f7e1e..a0b2699a28f 100644 --- a/kernels/test/op_argmin_test.cpp +++ b/kernels/test/op_argmin_test.cpp @@ -90,3 +90,16 @@ TEST_F(OpArgminTest, SanityCheckNullDim) { EXPECT_TENSOR_EQ(out, expected); // clang-format on } + +TEST_F(OpArgminTest, FirstNaNWins) { + TensorFactory tf_float; + Tensor in = tf_float.make({4}, {1, NAN, -4, NAN}); + + TensorFactory tf_long; + Tensor out = tf_long.zeros({}); + Tensor expected = tf_long.make({}, {1}); + + Tensor ret = op_argmin_out(in, {}, false, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); +} From d3a0f675767cf527ac680f32bbaa0b0fe97e8804 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 28 Feb 2025 17:16:14 -0800 Subject: [PATCH 2/8] Update [ghstack-poisoned] --- .../cpu/util/delinearized_indexes_range.h | 108 ++++++++++++++++++ kernels/portable/cpu/util/targets.bzl | 13 +++ kernels/portable/cpu/util/test/CMakeLists.txt | 4 +- .../test/delinearized_indexes_range_test.cpp | 69 +++++++++++ kernels/portable/cpu/util/test/targets.bzl | 10 ++ test/utils/OSSTestConfig.json | 4 +- 6 files changed, 206 insertions(+), 2 deletions(-) create mode 100644 kernels/portable/cpu/util/delinearized_indexes_range.h create mode 100644 kernels/portable/cpu/util/test/delinearized_indexes_range_test.cpp diff --git a/kernels/portable/cpu/util/delinearized_indexes_range.h b/kernels/portable/cpu/util/delinearized_indexes_range.h new file mode 100644 index 00000000000..c6a9cc91bfb --- /dev/null +++ b/kernels/portable/cpu/util/delinearized_indexes_range.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace torch::executor { + +namespace internal { +class DelinearizedIndexesIterator { + public: + using difference_type = ssize_t; + using value_type = std::array; + using reference = const value_type&; + using pointer = const value_type*; + using iterator_category = std::forward_iterator_tag; + + DelinearizedIndexesIterator() = default; + + explicit DelinearizedIndexesIterator(const Tensor& t) + : idx_(0), dim_(t.dim()), shape_(t.sizes()) { + } + + struct make_end_t { + explicit constexpr make_end_t() = default; + }; + + DelinearizedIndexesIterator(make_end_t, const Tensor& t) + : idx_(t.numel()) {} + + bool operator==(const DelinearizedIndexesIterator& rhs) const { + return idx_ == rhs.idx_; + } + + bool operator!=(const DelinearizedIndexesIterator& rhs) const { + return !operator==(rhs); + } + + reference operator*() const { + return repr_; + } + + pointer operator->() const { + return &repr_; + } + + DelinearizedIndexesIterator& operator++() { + idx_++; + for (auto ii = dim_ - 1; ii >= 0; --ii) { + repr_[ii]++; + ET_DCHECK(repr_[ii] <= shape_[ii]); + if ET_LIKELY (repr_[ii] < shape_[ii]) { + break; + } else { + repr_[ii] = 0; + } + } + return *this; + } + + DelinearizedIndexesIterator operator++(int) { + auto it = *this; + operator++(); + return it; + } + + difference_type operator-(const DelinearizedIndexesIterator& rhs) const { + return difference_type(idx_ - rhs.idx_); + } + + private: + std::size_t idx_ = 0; + value_type repr_ = {0,}; + ssize_t dim_; + ArrayRef shape_; +}; +} // namespace internal + +class DelinearizedIndexesRange { + public: + using iterator = internal::DelinearizedIndexesIterator; + + DelinearizedIndexesRange(const Tensor& t) : + tensor_(t) {} + + iterator begin() const { + return iterator(tensor_); + } + + iterator end() { + return iterator(iterator::make_end_t(), tensor_); + } + private: + const Tensor& tensor_; +}; +} // namespace torch::executor diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index eef765d5eec..95970b600ef 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -280,6 +280,19 @@ def define_common_targets(): visibility = ["//executorch/kernels/portable/cpu/..."], ) + runtime.cxx_library( + name = "delinearized_indexes_range", + exported_headers = ["delinearized_indexes_range.h"], + deps = [ + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/exec_aten/util:tensor_dimension_limit", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + ) + # Utility functions that can be used by operators that perform reduction for aten_mode in get_aten_mode_options(): suffix = "_aten" if aten_mode else "" diff --git a/kernels/portable/cpu/util/test/CMakeLists.txt b/kernels/portable/cpu/util/test/CMakeLists.txt index 5f81e4b6aec..76c53ea8448 100644 --- a/kernels/portable/cpu/util/test/CMakeLists.txt +++ b/kernels/portable/cpu/util/test/CMakeLists.txt @@ -19,7 +19,9 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) include(${EXECUTORCH_ROOT}/build/Test.cmake) -set(_test_srcs broadcast_test.cpp reduce_test.cpp) +set(_test_srcs broadcast_test.cpp delinearized_indexes_range_test.cpp + reduce_test.cpp +) et_cxx_test( kernels_portable_cpu_util_test SOURCES ${_test_srcs} EXTRA_LIBS diff --git a/kernels/portable/cpu/util/test/delinearized_indexes_range_test.cpp b/kernels/portable/cpu/util/test/delinearized_indexes_range_test.cpp new file mode 100644 index 00000000000..395e1a74d98 --- /dev/null +++ b/kernels/portable/cpu/util/test/delinearized_indexes_range_test.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::testing::TensorFactory; +using torch::executor::DelinearizedIndexesRange; + +TEST(DelinearizedIndexesRangeTest, Empty) { + TensorFactory tf; + + Tensor a = tf.make({0}, {}); + ASSERT_EQ(a.numel(), 0); + bool loop_entered = false; + for (auto _ : DelinearizedIndexesRange(a)) { + loop_entered = true; + } + EXPECT_FALSE(loop_entered); +} + +TEST(DelinearizedIndexesRangeTest, OneD) { + TensorFactory tf; + + Tensor a = tf.zeros({5}); + DelinearizedIndexesRange r(a); + std::vector v(r.begin(), r.end()); + int idx = 0; + for (const auto& elem: v) { + EXPECT_EQ(elem[0], idx++); + } +} + +TEST(DelinearizedIndexesRangeTest, ThreeD) { + TensorFactory tf; + Tensor a = tf.zeros({3, 2, 3}); + DelinearizedIndexesRange r(a); + std::vector v(r.begin(), r.end()); + std::vector expected = { + {0, 0, 0}, + {0, 0, 1}, + {0, 0, 2}, + {0, 1, 0}, + {0, 1, 1}, + {0, 1, 2}, + {1, 0, 0}, + {1, 0, 1}, + {1, 0, 2}, + {1, 1, 0}, + {1, 1, 1}, + {1, 1, 2}, + {2, 0, 0}, + {2, 0, 1}, + {2, 0, 2}, + {2, 1, 0}, + {2, 1, 1}, + {2, 1, 2}, + }; + EXPECT_EQ(v, expected); +} diff --git a/kernels/portable/cpu/util/test/targets.bzl b/kernels/portable/cpu/util/test/targets.bzl index 28988b90dcc..25e16237361 100644 --- a/kernels/portable/cpu/util/test/targets.bzl +++ b/kernels/portable/cpu/util/test/targets.bzl @@ -12,6 +12,16 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "delinearized_indexes_range_test", + srcs = ["delinearized_indexes_range_test.cpp"], + deps = [ + "//executorch/kernels/portable/cpu/util:delinearized_indexes_range", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + ], + ) + runtime.cxx_test( name = "reduce_test", srcs = ["reduce_test.cpp"], diff --git a/test/utils/OSSTestConfig.json b/test/utils/OSSTestConfig.json index 70cb2d2e44f..b94f11ea94e 100644 --- a/test/utils/OSSTestConfig.json +++ b/test/utils/OSSTestConfig.json @@ -7,7 +7,8 @@ "op_fast_hadamard_transform_test.cpp" ], "additional_libs": [ - "custom_ops" + "custom_ops", + "dumb_fht" ] }, { @@ -62,6 +63,7 @@ "directory": "kernels/portable/cpu/util/test", "sources": [ "broadcast_test.cpp", + "delinearized_indexes_range_test.cpp", "reduce_test.cpp" ], "additional_libs": [ From a78277df2b5e753b672b762a71110cf2c0284ab9 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 28 Feb 2025 17:16:18 -0800 Subject: [PATCH 3/8] Update [ghstack-poisoned] --- kernels/portable/cpu/util/broadcast_util.h | 57 +++++++------- kernels/portable/cpu/util/elementwise_util.h | 79 ++++++++++++-------- 2 files changed, 79 insertions(+), 57 deletions(-) diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h index 10bd07baee2..47ac9da4af2 100644 --- a/kernels/portable/cpu/util/broadcast_util.h +++ b/kernels/portable/cpu/util/broadcast_util.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include @@ -290,23 +291,27 @@ inline void apply_binary_elementwise_fn( const CTYPE_B* const data_b = b.const_data_ptr(); CTYPE_OUT* const data_out = out.mutable_data_ptr(); - for (const auto i : c10::irange(out.numel())) { - size_t a_linear_index = i; - size_t b_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - + if (any_is_broadcasted) { + size_t i = 0; + for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) { + size_t a_linear_index = i; + size_t b_linear_index = i; if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a); } if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b); } + + data_out[i++] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]); } + } else { + for (const auto i : c10::irange(out.numel())) { + size_t a_linear_index = i; + size_t b_linear_index = i; - data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]); + data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]); + } } } @@ -338,28 +343,28 @@ inline void apply_ternary_elementwise_fn( const CTYPE_C* const data_c = c.const_data_ptr(); CTYPE_OUT* const data_out = out.mutable_data_ptr(); - for (const auto i : c10::irange(out.numel())) { - size_t a_linear_index = i; - size_t b_linear_index = i; - size_t c_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - + if (any_is_broadcasted) { + size_t i = 0; + for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) { + size_t a_linear_index = i; + size_t b_linear_index = i; + size_t c_linear_index = i; if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a); } if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b); } if (c_is_broadcasted) { - c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c); + c_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), c); } - } - data_out[i] = compute_fun( - data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]); + data_out[i++] = compute_fun(data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]); + } + } else { + for (const auto i : c10::irange(out.numel())) { + data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]); + } } } diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 778006f1b99..ee19a3640fb 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -121,26 +122,33 @@ inline void apply_bitensor_elementwise_fn( char* const data_out = reinterpret_cast(out.mutable_data_ptr()); auto out_numel = out.numel(); - for (const auto i : c10::irange(out_numel)) { - size_t a_linear_index = i; - size_t b_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - + if (any_is_broadcasted) { + size_t i = 0; + for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) { + size_t a_linear_index = i; + size_t b_linear_index = i; if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a); } if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b); } + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); + i++; } + } else { + for (const auto i : c10::irange(out_numel)) { + size_t a_linear_index = i; + size_t b_linear_index = i; - auto result = compute_fun( - load_a_to_common(&data_a[a_linear_index * a_element_size]), - load_b_to_common(&data_b[b_linear_index * b_element_size])); - store_common_to_out(result, &data_out[i * out_element_size]); + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); + } } } @@ -211,31 +219,40 @@ inline void apply_tritensor_elementwise_fn( char* const data_out = reinterpret_cast(out.mutable_data_ptr()); auto out_numel = out.numel(); - for (const auto i : c10::irange(out_numel)) { - size_t a_linear_index = i; - size_t b_linear_index = i; - size_t c_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - + if (any_is_broadcasted) { + size_t i = 0; + for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) { + size_t a_linear_index = i; + size_t b_linear_index = i; + size_t c_linear_index = i; if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a); } if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b); } if (c_is_broadcasted) { - c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c); + c_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), c); } + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size]), + load_c_to_common(&data_c[c_linear_index * c_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); + i++; } + } else { + for (const auto i : c10::irange(out_numel)) { + size_t a_linear_index = i; + size_t b_linear_index = i; + size_t c_linear_index = i; - auto result = compute_fun( - load_a_to_common(&data_a[a_linear_index * a_element_size]), - load_b_to_common(&data_b[b_linear_index * b_element_size]), - load_c_to_common(&data_c[c_linear_index * c_element_size])); - store_common_to_out(result, &data_out[i * out_element_size]); + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size]), + load_c_to_common(&data_c[c_linear_index * c_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); + } } } From 6b6180c123f5d5af06f7c3fafcf945afe070f77b Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 28 Feb 2025 17:16:23 -0800 Subject: [PATCH 4/8] Update [ghstack-poisoned] --- backends/xnnpack/CMakeLists.txt | 2 +- kernels/optimized/cpu/targets.bzl | 6 ++++++ kernels/optimized/optimized.yaml | 5 +++++ kernels/test/CMakeLists.txt | 1 + 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index a703d67c1b2..56b8ba96a05 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -139,7 +139,7 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$") endif() target_link_libraries( - xnn_executor_runner gflags portable_ops_lib ${xnn_executor_runner_libs} + xnn_executor_runner gflags optimized_native_cpu_ops_lib ${xnn_executor_runner_libs} ) target_compile_options(xnn_executor_runner PUBLIC ${_common_compile_options}) endif() diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index 83b2c320266..dc189708992 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -95,6 +95,12 @@ _OPTIMIZED_ATEN_OPS = ( "//executorch/kernels/portable/cpu/util:broadcast_util", ], ), + op_target( + name = "op_where", + deps = [ + "//executorch/kernels/portable/cpu/util:elementwise_util", + ], + ), ) diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml index fd5143b1511..4f90059aa93 100644 --- a/kernels/optimized/optimized.yaml +++ b/kernels/optimized/optimized.yaml @@ -101,3 +101,8 @@ kernels: - arg_meta: null kernel_name: torch::executor::opt_sub_scalar_out + +- op: where.self_out + kernels: + - arg_meta: null + kernel_name: torch::executor::opt_where_out diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index 24adb8d9c80..394ec241698 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -275,6 +275,7 @@ set(_optimized_kernels_test_sources "op_native_layer_norm_test.cpp" "op_neg_test.cpp" "op_sub_test.cpp" + "op_where_test.cpp" "UnaryUfuncRealHBBF16ToFloatHBF16Test.cpp" ${CMAKE_CURRENT_BINARY_DIR}/include/optimized/executorch/kernels/test/supported_features.cpp ) From c4e4541d017b4a1ade5f348bc13d6faf316aadc4 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 3 Mar 2025 14:45:49 -0800 Subject: [PATCH 5/8] Update [ghstack-poisoned] --- backends/xnnpack/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index 56b8ba96a05..a703d67c1b2 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -139,7 +139,7 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$") endif() target_link_libraries( - xnn_executor_runner gflags optimized_native_cpu_ops_lib ${xnn_executor_runner_libs} + xnn_executor_runner gflags portable_ops_lib ${xnn_executor_runner_libs} ) target_compile_options(xnn_executor_runner PUBLIC ${_common_compile_options}) endif() From d3edbcb227073c5edc3ca90ea3a20a2a49e6094b Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 3 Mar 2025 15:31:25 -0800 Subject: [PATCH 6/8] Update [ghstack-poisoned] --- kernels/portable/cpu/util/targets.bzl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index c42f38fd8b0..739bc117fbf 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -70,6 +70,9 @@ def define_common_targets(): exported_headers = [ "broadcast_util.h", ], + exported_deps = [ + ":broadcast_indexes_range", + ], deps = [ ":repeat_util", "//executorch/runtime/kernel:kernel_includes", From 14a55be16b0f20ecb614aa51b1333d42f56f8e80 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 4 Mar 2025 10:27:18 -0800 Subject: [PATCH 7/8] Update [ghstack-poisoned] --- kernels/portable/cpu/util/broadcast_indexes_range.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h index bebf5a056e6..7ee4f3fb2b1 100644 --- a/kernels/portable/cpu/util/broadcast_indexes_range.h +++ b/kernels/portable/cpu/util/broadcast_indexes_range.h @@ -158,7 +158,8 @@ class BroadcastIndexesIterator { // output_dim. This is straightforwardly implementable with an // adjusted stride array that contains 0s where the padded input // shape would contain 1s. - std::array effective_input_broadcast_strides_ = {{0}}; + std::array effective_input_broadcast_strides_ = { + {{0}}}; }; } // namespace internal From e26a958d9940cca470a564d38ff86cd10e0ad5d2 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 5 Mar 2025 09:50:33 -0800 Subject: [PATCH 8/8] Update [ghstack-poisoned] --- .../test/broadcast_indexes_range_test.cpp | 80 +++++++++++++++++-- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp index d1db40fca48..f147958558d 100644 --- a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp +++ b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp @@ -112,8 +112,66 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDExhaustive) { EXPECT_EQ(expected, actual); } -// Here we assume that the previous tests established that padding -// with leading 1s is working, and test: +// Make sure nothing is thrown off by a size-1 dim in the output: +// [] -> [1, W] +// [] -> [H, 1] +// [1] -> [1, W] +// [1] -> [H, 1] +// [W] -> [1, W] +// [1, 1] -> [1, W] +// [1, 1] -> [H, 1] +// [1, W] -> [1, W] +// [H, 1] -> [H, 1] +TEST(BroadcastIndexesRangeTest, OneAndTwoDWith1InOutputShapeExhaustive) { + TensorFactory tf; + constexpr auto H = 2; + constexpr auto W = 3; + Tensor out_row = tf.zeros({1, W}); + Tensor out_col = tf.zeros({H, 1}); + Tensor in_0d_scalar = tf.zeros({}); + Tensor in_1d_scalar = tf.zeros({1}); + Tensor in_2d_scalar = tf.zeros({1, 1}); + + Tensor in_row = tf.zeros({W}); + Tensor in_leading_one_row = tf.zeros({1, W}); + + Tensor in_col = tf.zeros({H, 1}); + + size_t idx = 0; + for (const auto + [out_idx, + in_0d_idx, + in_1d_idx, + in_2d_idx, + in_row_idx, + in_leading_one_row_idx] : + BroadcastIndexesRange<5>( + out_row, + in_0d_scalar, + in_1d_scalar, + in_2d_scalar, + in_row, + in_leading_one_row)) { + EXPECT_EQ(out_idx, idx++); + EXPECT_EQ(in_0d_idx, 0); + EXPECT_EQ(in_1d_idx, 0); + EXPECT_EQ(in_2d_idx, 0); + EXPECT_EQ(in_row_idx, out_idx); + EXPECT_EQ(in_leading_one_row_idx, out_idx); + } + + idx = 0; + for (const auto [out_idx, in_0d_idx, in_1d_idx, in_2d_idx, in_col_idx] : + BroadcastIndexesRange<4>( + out_col, in_0d_scalar, in_1d_scalar, in_2d_scalar, in_col)) { + EXPECT_EQ(out_idx, idx++); + EXPECT_EQ(in_0d_idx, 0); + EXPECT_EQ(in_1d_idx, 0); + EXPECT_EQ(in_2d_idx, 0); + EXPECT_EQ(in_col_idx, out_idx); + } +} + // [1, 1, 1] -> [C, H, W] // [C, H, 1] -> [C, H, W] // [C, 1, W] -> [C, H, W] @@ -166,11 +224,12 @@ TEST(BroadcastIndexesRangeTest, ThreeDBroadcasting) { // 4-D should generalize, but we will go ahead and test: // [N, 1, H, 1] -> [N, C, H, W] // [1, C, 1, W] -> [N, C, H, W] -TEST(BroadcastIndexesRangeTest, FourDBroadcasting) { +template +void four_d_broadcasting_test() { TensorFactory tf; - Tensor out = tf.zeros({2, 3, 4, 5}); - Tensor in_broadcast_cw = tf.zeros({2, 1, 4, 1}); - Tensor in_broadcast_nh = tf.zeros({1, 3, 1, 5}); + Tensor out = tf.zeros({N, C, H, W}); + Tensor in_broadcast_cw = tf.zeros({N, 1, H, 1}); + Tensor in_broadcast_nh = tf.zeros({1, C, 1, W}); // Writing out all the indexes would be too cumbersome, so here we // take the opportunity to mutation test against delinearize_index @@ -190,3 +249,12 @@ TEST(BroadcastIndexesRangeTest, FourDBroadcasting) { linearize_access_indexes(out_indexes, out.dim(), in_broadcast_nh)); } } + +TEST(BroadcastIndexesRangeTest, FourDBroadcasting) { + four_d_broadcasting_test<2, 3, 4, 5>(); +} + +TEST(BroadcastIndexesRangeTest, FourDBroadcastingWithOneDimsInOutput) { + four_d_broadcasting_test<2, 3, 1, 5>(); + four_d_broadcasting_test<2, 1, 3, 1>(); +}