From f1ace777b56cba104026b59dad0d1616e8caab0f Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 28 Feb 2025 17:16:10 -0800 Subject: [PATCH 01/11] Update [ghstack-poisoned] --- kernels/portable/cpu/op_argmax.cpp | 5 ++++- kernels/portable/cpu/op_argmin.cpp | 12 +++++++++++- kernels/test/op_argmax_test.cpp | 13 +++++++++++++ kernels/test/op_argmin_test.cpp | 13 +++++++++++++ 4 files changed, 41 insertions(+), 2 deletions(-) diff --git a/kernels/portable/cpu/op_argmax.cpp b/kernels/portable/cpu/op_argmax.cpp index 5eb656d5b76..7e95c305e89 100644 --- a/kernels/portable/cpu/op_argmax.cpp +++ b/kernels/portable/cpu/op_argmax.cpp @@ -49,7 +49,10 @@ Tensor& argmax_out( for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>( [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) { - if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) { + // the below condition as written is equivalent to + // !isnan(acc_val) && (isnan(v) || v > acc_val). See + // argument in op_argmin.cpp. + if (!std::isnan(acc_val) && !(v <= acc_val)) { acc_val = v; acc_ix = ix; } diff --git a/kernels/portable/cpu/op_argmin.cpp b/kernels/portable/cpu/op_argmin.cpp index 1c4a2572ea8..8a94994c4be 100644 --- a/kernels/portable/cpu/op_argmin.cpp +++ b/kernels/portable/cpu/op_argmin.cpp @@ -49,7 +49,17 @@ Tensor& argmin_out( for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>( [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) { - if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) { + // the below condition as written is equivalent to !isnan(acc_val) && + // (isnan(v) || v < acc_val). cases: + // - if neither acc_val nor v is NaN, !(v >= acc_val) is + // trivially equivalent to v < acc_val. + // - if acc_val is NaN, the whole thing is trivially false. + // - if acc_val is not NaN and v is NaN, then v >= acc_val + // is false because all comparisons involving NaN are + // false, so the result is true. The result is trivially + // true for the above condition that uses isnan(v) as + // well.
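The case analysis above can be checked mechanically. A minimal standalone sketch (not part of the patch; the value set and names are illustrative) that asserts the claimed equivalence:

#include <cassert>
#include <cmath>

int main() {
  // Representative values, including NaN, to hit every case in the comment.
  const float vals[] = {-1.0f, 0.0f, 2.5f, NAN};
  for (float acc_val : vals) {
    if (std::isnan(acc_val)) {
      continue; // the enclosing condition already rejects a NaN acc_val
    }
    for (float v : vals) {
      const bool rewritten = !(v >= acc_val);
      const bool spelled_out = std::isnan(v) || v < acc_val;
      assert(rewritten == spelled_out);
    }
  }
  return 0;
}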
+ if (!std::isnan(acc_val) && !(v >= acc_val)) { acc_val = v; acc_ix = ix; } diff --git a/kernels/test/op_argmax_test.cpp b/kernels/test/op_argmax_test.cpp index 66c79cefff7..4d68dfe88be 100644 --- a/kernels/test/op_argmax_test.cpp +++ b/kernels/test/op_argmax_test.cpp @@ -90,3 +90,16 @@ TEST_F(OpArgmaxTest, SanityCheckNullDim) { EXPECT_TENSOR_EQ(out, expected); // clang-format on } + +TEST_F(OpArgmaxTest, FirstNaNWins) { + TensorFactory<ScalarType::Float> tf_float; + Tensor in = tf_float.make({4}, {1, NAN, -4, NAN}); + + TensorFactory<ScalarType::Long> tf_long; + Tensor out = tf_long.zeros({}); + Tensor expected = tf_long.make({}, {1}); + + Tensor ret = op_argmax_out(in, {}, false, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); +} diff --git a/kernels/test/op_argmin_test.cpp b/kernels/test/op_argmin_test.cpp index 250fe4f7e1e..a0b2699a28f 100644 --- a/kernels/test/op_argmin_test.cpp +++ b/kernels/test/op_argmin_test.cpp @@ -90,3 +90,16 @@ TEST_F(OpArgminTest, SanityCheckNullDim) { EXPECT_TENSOR_EQ(out, expected); // clang-format on } + +TEST_F(OpArgminTest, FirstNaNWins) { + TensorFactory<ScalarType::Float> tf_float; + Tensor in = tf_float.make({4}, {1, NAN, -4, NAN}); + + TensorFactory<ScalarType::Long> tf_long; + Tensor out = tf_long.zeros({}); + Tensor expected = tf_long.make({}, {1}); + + Tensor ret = op_argmin_out(in, {}, false, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); +} From d3a0f675767cf527ac680f32bbaa0b0fe97e8804 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 28 Feb 2025 17:16:14 -0800 Subject: [PATCH 02/11] Update [ghstack-poisoned] --- .../cpu/util/delinearized_indexes_range.h | 108 ++++++++++++++++++ kernels/portable/cpu/util/targets.bzl | 13 +++ kernels/portable/cpu/util/test/CMakeLists.txt | 4 +- .../test/delinearized_indexes_range_test.cpp | 69 +++++++++++ kernels/portable/cpu/util/test/targets.bzl | 10 ++ test/utils/OSSTestConfig.json | 4 +- 6 files changed, 206 insertions(+), 2 deletions(-) create mode 100644 kernels/portable/cpu/util/delinearized_indexes_range.h create mode 100644 kernels/portable/cpu/util/test/delinearized_indexes_range_test.cpp diff --git a/kernels/portable/cpu/util/delinearized_indexes_range.h b/kernels/portable/cpu/util/delinearized_indexes_range.h new file mode 100644 index 00000000000..c6a9cc91bfb --- /dev/null +++ b/kernels/portable/cpu/util/delinearized_indexes_range.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree.
+ */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace torch::executor { + +namespace internal { +class DelinearizedIndexesIterator { + public: + using difference_type = ssize_t; + using value_type = std::array; + using reference = const value_type&; + using pointer = const value_type*; + using iterator_category = std::forward_iterator_tag; + + DelinearizedIndexesIterator() = default; + + explicit DelinearizedIndexesIterator(const Tensor& t) + : idx_(0), dim_(t.dim()), shape_(t.sizes()) { + } + + struct make_end_t { + explicit constexpr make_end_t() = default; + }; + + DelinearizedIndexesIterator(make_end_t, const Tensor& t) + : idx_(t.numel()) {} + + bool operator==(const DelinearizedIndexesIterator& rhs) const { + return idx_ == rhs.idx_; + } + + bool operator!=(const DelinearizedIndexesIterator& rhs) const { + return !operator==(rhs); + } + + reference operator*() const { + return repr_; + } + + pointer operator->() const { + return &repr_; + } + + DelinearizedIndexesIterator& operator++() { + idx_++; + for (auto ii = dim_ - 1; ii >= 0; --ii) { + repr_[ii]++; + ET_DCHECK(repr_[ii] <= shape_[ii]); + if ET_LIKELY (repr_[ii] < shape_[ii]) { + break; + } else { + repr_[ii] = 0; + } + } + return *this; + } + + DelinearizedIndexesIterator operator++(int) { + auto it = *this; + operator++(); + return it; + } + + difference_type operator-(const DelinearizedIndexesIterator& rhs) const { + return difference_type(idx_ - rhs.idx_); + } + + private: + std::size_t idx_ = 0; + value_type repr_ = {0,}; + ssize_t dim_; + ArrayRef shape_; +}; +} // namespace internal + +class DelinearizedIndexesRange { + public: + using iterator = internal::DelinearizedIndexesIterator; + + DelinearizedIndexesRange(const Tensor& t) : + tensor_(t) {} + + iterator begin() const { + return iterator(tensor_); + } + + iterator end() { + return iterator(iterator::make_end_t(), tensor_); + } + private: + const Tensor& tensor_; +}; +} // namespace torch::executor diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index eef765d5eec..95970b600ef 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -280,6 +280,19 @@ def define_common_targets(): visibility = ["//executorch/kernels/portable/cpu/..."], ) + runtime.cxx_library( + name = "delinearized_indexes_range", + exported_headers = ["delinearized_indexes_range.h"], + deps = [ + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/exec_aten/util:tensor_dimension_limit", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + ) + # Utility functions that can be used by operators that perform reduction for aten_mode in get_aten_mode_options(): suffix = "_aten" if aten_mode else "" diff --git a/kernels/portable/cpu/util/test/CMakeLists.txt b/kernels/portable/cpu/util/test/CMakeLists.txt index 5f81e4b6aec..76c53ea8448 100644 --- a/kernels/portable/cpu/util/test/CMakeLists.txt +++ b/kernels/portable/cpu/util/test/CMakeLists.txt @@ -19,7 +19,9 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) 
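For orientation, a minimal usage sketch of the header above (the function is hypothetical, not part of the patch; it assumes an ExecuTorch Tensor):

#include <cstddef>

#include <executorch/kernels/portable/cpu/util/delinearized_indexes_range.h>

using torch::executor::DelinearizedIndexesRange;
using torch::executor::Tensor;

// Visits every element of `t` in row-major order. The range yields the
// per-dimension indexes of each element, maintained incrementally rather
// than recomputed with division/modulus per element.
std::size_t sum_innermost_indexes(const Tensor& t) {
  if (t.dim() == 0) {
    return 0; // a scalar has no innermost dimension to index
  }
  std::size_t total = 0;
  for (const auto& indexes : DelinearizedIndexesRange(t)) {
    total += indexes[t.dim() - 1];
  }
  return total;
}

Because the iterator is a forward iterator, the range can also be fed to containers and algorithms, as the tests below do with std::vector.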
include(${EXECUTORCH_ROOT}/build/Test.cmake) -set(_test_srcs broadcast_test.cpp reduce_test.cpp) +set(_test_srcs broadcast_test.cpp delinearized_indexes_range_test.cpp + reduce_test.cpp +) et_cxx_test( kernels_portable_cpu_util_test SOURCES ${_test_srcs} EXTRA_LIBS diff --git a/kernels/portable/cpu/util/test/delinearized_indexes_range_test.cpp b/kernels/portable/cpu/util/test/delinearized_indexes_range_test.cpp new file mode 100644 index 00000000000..395e1a74d98 --- /dev/null +++ b/kernels/portable/cpu/util/test/delinearized_indexes_range_test.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::testing::TensorFactory; +using torch::executor::DelinearizedIndexesRange; + +TEST(DelinearizedIndexesRangeTest, Empty) { + TensorFactory tf; + + Tensor a = tf.make({0}, {}); + ASSERT_EQ(a.numel(), 0); + bool loop_entered = false; + for (auto _ : DelinearizedIndexesRange(a)) { + loop_entered = true; + } + EXPECT_FALSE(loop_entered); +} + +TEST(DelinearizedIndexesRangeTest, OneD) { + TensorFactory tf; + + Tensor a = tf.zeros({5}); + DelinearizedIndexesRange r(a); + std::vector v(r.begin(), r.end()); + int idx = 0; + for (const auto& elem: v) { + EXPECT_EQ(elem[0], idx++); + } +} + +TEST(DelinearizedIndexesRangeTest, ThreeD) { + TensorFactory tf; + Tensor a = tf.zeros({3, 2, 3}); + DelinearizedIndexesRange r(a); + std::vector v(r.begin(), r.end()); + std::vector expected = { + {0, 0, 0}, + {0, 0, 1}, + {0, 0, 2}, + {0, 1, 0}, + {0, 1, 1}, + {0, 1, 2}, + {1, 0, 0}, + {1, 0, 1}, + {1, 0, 2}, + {1, 1, 0}, + {1, 1, 1}, + {1, 1, 2}, + {2, 0, 0}, + {2, 0, 1}, + {2, 0, 2}, + {2, 1, 0}, + {2, 1, 1}, + {2, 1, 2}, + }; + EXPECT_EQ(v, expected); +} diff --git a/kernels/portable/cpu/util/test/targets.bzl b/kernels/portable/cpu/util/test/targets.bzl index 28988b90dcc..25e16237361 100644 --- a/kernels/portable/cpu/util/test/targets.bzl +++ b/kernels/portable/cpu/util/test/targets.bzl @@ -12,6 +12,16 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "delinearized_indexes_range_test", + srcs = ["delinearized_indexes_range_test.cpp"], + deps = [ + "//executorch/kernels/portable/cpu/util:delinearized_indexes_range", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + ], + ) + runtime.cxx_test( name = "reduce_test", srcs = ["reduce_test.cpp"], diff --git a/test/utils/OSSTestConfig.json b/test/utils/OSSTestConfig.json index 70cb2d2e44f..b94f11ea94e 100644 --- a/test/utils/OSSTestConfig.json +++ b/test/utils/OSSTestConfig.json @@ -7,7 +7,8 @@ "op_fast_hadamard_transform_test.cpp" ], "additional_libs": [ - "custom_ops" + "custom_ops", + "dumb_fht" ] }, { @@ -62,6 +63,7 @@ "directory": "kernels/portable/cpu/util/test", "sources": [ "broadcast_test.cpp", + "delinearized_indexes_range_test.cpp", "reduce_test.cpp" ], "additional_libs": [ From a78277df2b5e753b672b762a71110cf2c0284ab9 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 28 Feb 2025 17:16:18 -0800 Subject: [PATCH 03/11] Update [ghstack-poisoned] --- kernels/portable/cpu/util/broadcast_util.h | 57 +++++++------- kernels/portable/cpu/util/elementwise_util.h | 79 ++++++++++++-------- 2 files changed, 79 insertions(+), 57 deletions(-) diff --git 
a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h index 10bd07baee2..47ac9da4af2 100644 --- a/kernels/portable/cpu/util/broadcast_util.h +++ b/kernels/portable/cpu/util/broadcast_util.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include @@ -290,23 +291,27 @@ inline void apply_binary_elementwise_fn( const CTYPE_B* const data_b = b.const_data_ptr(); CTYPE_OUT* const data_out = out.mutable_data_ptr(); - for (const auto i : c10::irange(out.numel())) { - size_t a_linear_index = i; - size_t b_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - + if (any_is_broadcasted) { + size_t i = 0; + for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) { + size_t a_linear_index = i; + size_t b_linear_index = i; if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a); } if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b); } + + data_out[i++] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]); } + } else { + for (const auto i : c10::irange(out.numel())) { + size_t a_linear_index = i; + size_t b_linear_index = i; - data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]); + data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]); + } } } @@ -338,28 +343,28 @@ inline void apply_ternary_elementwise_fn( const CTYPE_C* const data_c = c.const_data_ptr(); CTYPE_OUT* const data_out = out.mutable_data_ptr(); - for (const auto i : c10::irange(out.numel())) { - size_t a_linear_index = i; - size_t b_linear_index = i; - size_t c_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - + if (any_is_broadcasted) { + size_t i = 0; + for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) { + size_t a_linear_index = i; + size_t b_linear_index = i; + size_t c_linear_index = i; if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a); } if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b); } if (c_is_broadcasted) { - c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c); + c_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), c); } - } - data_out[i] = compute_fun( - data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]); + data_out[i++] = compute_fun(data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]); + } + } else { + for (const auto i : c10::irange(out.numel())) { + data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]); + } } } diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 778006f1b99..ee19a3640fb 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -121,26 +122,33 @@ inline void apply_bitensor_elementwise_fn( char* const data_out = 
reinterpret_cast(out.mutable_data_ptr()); auto out_numel = out.numel(); - for (const auto i : c10::irange(out_numel)) { - size_t a_linear_index = i; - size_t b_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - + if (any_is_broadcasted) { + size_t i = 0; + for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) { + size_t a_linear_index = i; + size_t b_linear_index = i; if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a); } if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b); } + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); + i++; } + } else { + for (const auto i : c10::irange(out_numel)) { + size_t a_linear_index = i; + size_t b_linear_index = i; - auto result = compute_fun( - load_a_to_common(&data_a[a_linear_index * a_element_size]), - load_b_to_common(&data_b[b_linear_index * b_element_size])); - store_common_to_out(result, &data_out[i * out_element_size]); + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); + } } } @@ -211,31 +219,40 @@ inline void apply_tritensor_elementwise_fn( char* const data_out = reinterpret_cast(out.mutable_data_ptr()); auto out_numel = out.numel(); - for (const auto i : c10::irange(out_numel)) { - size_t a_linear_index = i; - size_t b_linear_index = i; - size_t c_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - + if (any_is_broadcasted) { + size_t i = 0; + for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) { + size_t a_linear_index = i; + size_t b_linear_index = i; + size_t c_linear_index = i; if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a); } if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b); } if (c_is_broadcasted) { - c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c); + c_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), c); } + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size]), + load_c_to_common(&data_c[c_linear_index * c_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); + i++; } + } else { + for (const auto i : c10::irange(out_numel)) { + size_t a_linear_index = i; + size_t b_linear_index = i; + size_t c_linear_index = i; - auto result = compute_fun( - load_a_to_common(&data_a[a_linear_index * a_element_size]), - load_b_to_common(&data_b[b_linear_index * b_element_size]), - load_c_to_common(&data_c[c_linear_index * c_element_size])); - store_common_to_out(result, &data_out[i * out_element_size]); + auto result = 
compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size]), + load_c_to_common(&data_c[c_linear_index * c_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); + } } } From 6b6180c123f5d5af06f7c3fafcf945afe070f77b Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 28 Feb 2025 17:16:23 -0800 Subject: [PATCH 04/11] Update [ghstack-poisoned] --- backends/xnnpack/CMakeLists.txt | 2 +- kernels/optimized/cpu/targets.bzl | 6 ++++++ kernels/optimized/optimized.yaml | 5 +++++ kernels/test/CMakeLists.txt | 1 + 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index a703d67c1b2..56b8ba96a05 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -139,7 +139,7 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$") endif() target_link_libraries( - xnn_executor_runner gflags portable_ops_lib ${xnn_executor_runner_libs} + xnn_executor_runner gflags optimized_native_cpu_ops_lib ${xnn_executor_runner_libs} ) target_compile_options(xnn_executor_runner PUBLIC ${_common_compile_options}) endif() diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index 83b2c320266..dc189708992 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -95,6 +95,12 @@ _OPTIMIZED_ATEN_OPS = ( "//executorch/kernels/portable/cpu/util:broadcast_util", ], ), + op_target( + name = "op_where", + deps = [ + "//executorch/kernels/portable/cpu/util:elementwise_util", + ], + ), ) diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml index fd5143b1511..4f90059aa93 100644 --- a/kernels/optimized/optimized.yaml +++ b/kernels/optimized/optimized.yaml @@ -101,3 +101,8 @@ kernels: - arg_meta: null kernel_name: torch::executor::opt_sub_scalar_out + +- op: where.self_out + kernels: + - arg_meta: null + kernel_name: torch::executor::opt_where_out diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index 24adb8d9c80..394ec241698 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -275,6 +275,7 @@ set(_optimized_kernels_test_sources "op_native_layer_norm_test.cpp" "op_neg_test.cpp" "op_sub_test.cpp" + "op_where_test.cpp" "UnaryUfuncRealHBBF16ToFloatHBF16Test.cpp" ${CMAKE_CURRENT_BINARY_DIR}/include/optimized/executorch/kernels/test/supported_features.cpp ) From c4e4541d017b4a1ade5f348bc13d6faf316aadc4 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 3 Mar 2025 14:45:49 -0800 Subject: [PATCH 05/11] Update [ghstack-poisoned] --- backends/xnnpack/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index 56b8ba96a05..a703d67c1b2 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -139,7 +139,7 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$") endif() target_link_libraries( - xnn_executor_runner gflags optimized_native_cpu_ops_lib ${xnn_executor_runner_libs} + xnn_executor_runner gflags portable_ops_lib ${xnn_executor_runner_libs} ) target_compile_options(xnn_executor_runner PUBLIC ${_common_compile_options}) endif() From 18583c3ba44e671a6b6aca686799c51985e1a8e7 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 3 Mar 2025 14:45:50 -0800 Subject: [PATCH 06/11] Update [ghstack-poisoned] --- backends/xnnpack/CMakeLists.txt | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index a703d67c1b2..56b8ba96a05 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -139,7 +139,7 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$") endif() target_link_libraries( - xnn_executor_runner gflags portable_ops_lib ${xnn_executor_runner_libs} + xnn_executor_runner gflags optimized_native_cpu_ops_lib ${xnn_executor_runner_libs} ) target_compile_options(xnn_executor_runner PUBLIC ${_common_compile_options}) endif() From d3edbcb227073c5edc3ca90ea3a20a2a49e6094b Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 3 Mar 2025 15:31:25 -0800 Subject: [PATCH 07/11] Update [ghstack-poisoned] --- kernels/portable/cpu/util/targets.bzl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index c42f38fd8b0..739bc117fbf 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -70,6 +70,9 @@ def define_common_targets(): exported_headers = [ "broadcast_util.h", ], + exported_deps = [ + ":broadcast_indexes_range", + ], deps = [ ":repeat_util", "//executorch/runtime/kernel:kernel_includes", From a93cd69438a6ed127eac41b449ef20fea332f8dc Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 3 Mar 2025 16:25:44 -0800 Subject: [PATCH 08/11] Update [ghstack-poisoned] --- kernels/optimized/cpu/op_where.cpp | 97 --------- kernels/optimized/cpu/targets.bzl | 6 - kernels/optimized/optimized.yaml | 5 - .../cpu/util/broadcast_indexes_range.h | 206 ------------------ kernels/portable/cpu/util/broadcast_util.h | 56 +++-- kernels/portable/cpu/util/elementwise_util.h | 81 +++---- kernels/portable/cpu/util/targets.bzl | 16 -- kernels/portable/cpu/util/test/CMakeLists.txt | 4 +- .../test/broadcast_indexes_range_test.cpp | 174 --------------- kernels/portable/cpu/util/test/targets.bzl | 11 - kernels/test/CMakeLists.txt | 1 - test/utils/OSSTestConfig.json | 4 +- 12 files changed, 81 insertions(+), 580 deletions(-) delete mode 100644 kernels/optimized/cpu/op_where.cpp delete mode 100644 kernels/portable/cpu/util/broadcast_indexes_range.h delete mode 100644 kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp diff --git a/kernels/optimized/cpu/op_where.cpp b/kernels/optimized/cpu/op_where.cpp deleted file mode 100644 index 4d897ea6281..00000000000 --- a/kernels/optimized/cpu/op_where.cpp +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ -#include -#include -#include - -namespace torch { -namespace executor { -namespace native { - -Tensor& opt_where_out( - KernelRuntimeContext& ctx, - const Tensor& cond, - const Tensor& a, - const Tensor& b, - Tensor& out) { - // Common Dtype - ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); - - // Check Common Dtype - ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); - - // Check Dim Order - ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out); - - // Resize - ET_KERNEL_CHECK( - ctx, - resize_to_broadcast_target_size(a, b, cond, out) == Error::Ok, - InvalidArgument, - out); - - // Compute Dtype - ScalarType compute_type = utils::get_compute_type(common_type); - - // @lint-ignore CLANGTIDY facebook-hte-CArray - static constexpr const char op_name[] = "where.self_out"; - - if (a.scalar_type() == b.scalar_type() && - a.scalar_type() == out.scalar_type() && a.scalar_type() == compute_type && - // Using a Byte tensor for cond has been deprecated for a long time. - cond.scalar_type() == ScalarType::Bool) { - auto out_numel = out.numel(); - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const bool a_is_broadcasted = !out.sizes().equals(a.sizes()); - const bool b_is_broadcasted = !out.sizes().equals(b.sizes()); - const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes()); - const bool any_is_broadcasted = - (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted); - const CTYPE_COMPUTE* const data_a = a.const_data_ptr(); - const CTYPE_COMPUTE* const data_b = b.const_data_ptr(); - const bool* const data_cond = cond.const_data_ptr(); - CTYPE_COMPUTE* const data_out = out.data_ptr(); - if (any_is_broadcasted) { - for (const auto [out_index, a_index, b_index, cond_index] : - BroadcastIndexesRange<3>(out, a, b, cond)) { - data_out[out_index] = - data_cond[cond_index] ? data_a[a_index] : data_b[b_index]; - } - } else { - for (const auto i : c10::irange(out_numel)) { - data_out[i] = data_cond[i] ? data_a[i] : data_b[i]; - } - } - }); - } else { - // Fall back for mixed dtype to keep code size and compile time - // reasonable. - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_tritensor_elementwise_fn( - [](const CTYPE_COMPUTE val_a, - const CTYPE_COMPUTE val_b, - const CTYPE_COMPUTE val_c) { return val_c ? 
val_a : val_b; }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - cond, - utils::SupportedTensorDtypes::BOOL_OR_BYTE, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); - }); - } - - return out; -} - -} // namespace native -} // namespace executor -} // namespace torch diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index dc189708992..83b2c320266 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -95,12 +95,6 @@ _OPTIMIZED_ATEN_OPS = ( "//executorch/kernels/portable/cpu/util:broadcast_util", ], ), - op_target( - name = "op_where", - deps = [ - "//executorch/kernels/portable/cpu/util:elementwise_util", - ], - ), ) diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml index 4f90059aa93..fd5143b1511 100644 --- a/kernels/optimized/optimized.yaml +++ b/kernels/optimized/optimized.yaml @@ -101,8 +101,3 @@ kernels: - arg_meta: null kernel_name: torch::executor::opt_sub_scalar_out - -- op: where.self_out - kernels: - - arg_meta: null - kernel_name: torch::executor::opt_where_out diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h deleted file mode 100644 index bebf5a056e6..00000000000 --- a/kernels/portable/cpu/util/broadcast_indexes_range.h +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include -#include -#include -#include - -#include -#include - -namespace torch::executor { - -namespace internal { -template -class BroadcastIndexesIterator { - public: - using difference_type = ssize_t; - using value_type = std::array; - using reference = const value_type&; - using pointer = const value_type*; - using iterator_category = std::forward_iterator_tag; - - BroadcastIndexesIterator() = default; - - template - explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args) - : output_dim_(output.dim()), - output_shape_(output.sizes()), - effective_input_broadcast_strides_{ - effective_input_broadcast_stride(output, args)...} { - static_assert( - sizeof...(args) == kNumInputs && (std::is_same_v && ...), - "BroadcastIndexesIterator constructor requires kNumInputs input tensor" - "arguments!"); - } - - struct make_end_t { - explicit constexpr make_end_t() = default; - }; - - template - BroadcastIndexesIterator(make_end_t, const Tensor& t, const Args&... args) - : current_indexes_{ - t.numel(), - 0, - } {} - - bool operator==(const BroadcastIndexesIterator& rhs) const { - return output_index() == rhs.output_index(); - } - - bool operator!=(const BroadcastIndexesIterator& rhs) const { - return !operator==(rhs); - } - - reference operator*() const { - return current_indexes_; - } - - pointer operator->() const { - return ¤t_indexes_; - } - - BroadcastIndexesIterator& operator++() { - output_index()++; - // TODO: add optimization for particular input tensors not being - // broadcasted? - for (auto ii = output_dim_ - 1; ii >= 0; --ii) { - // You might wonder what happens if output_shape_[ii] == 0. In that case, - // output.numel() would be 0, and thus the iterator would be the end() - // iterator, which is not legal to increment. 
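To make the increment step below concrete, a worked example (editorial note; the shapes are hypothetical):

// Worked example: output shape [2, 3], input shape [1, 3]. The input is
// broadcast over dim 0, so its effective broadcast strides are {0, 1}.
// As the output index steps through 0..5, the delinearized output index
// runs through (0,0) (0,1) (0,2) (1,0) (1,1) (1,2). Each step adds (or, on
// wraparound, subtracts) the effective stride of the dimension that
// changed, so the input index runs through 0, 1, 2, 0, 1, 2 with no
// division or modulus anywhere in the loop.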
- if ET_UNLIKELY (delinearized_output_index_[ii] == output_shape_[ii] - 1) { - const auto old_delinearized_output_index_item = - delinearized_output_index_[ii]; - delinearized_output_index_[ii] = 0; - for (const auto jj : c10::irange(1, kNumInputs + 1)) { - current_indexes_[jj] -= old_delinearized_output_index_item * - effective_input_broadcast_strides_[jj - 1][ii]; - } - } else { - delinearized_output_index_[ii]++; - for (const auto jj : c10::irange(1, kNumInputs + 1)) { - current_indexes_.at(jj) += - effective_input_broadcast_strides_[jj - 1][ii]; - } - break; - } - } - return *this; - } - - BroadcastIndexesIterator operator++(int) { - auto it = *this; - operator++(); - return it; - } - - difference_type operator-(const BroadcastIndexesIterator& rhs) const { - return difference_type(output_index() - rhs.output_index()); - } - - private: - ssize_t output_index() const { - return current_indexes_[0]; - } - - ssize_t& output_index() { - return current_indexes_[0]; - } - - std::array - effective_input_broadcast_stride(const Tensor& output, const Tensor& t) - const { - std::array - result = {0}; - ET_CHECK_MSG( - t.dim() <= output.dim(), - "input to broadcasting op should have dim at most output dim, but %d > %d!", - (int)t.dim(), - (int)output.dim()); - - const auto num_leading_ones = output.dim() - t.dim(); - for (const auto idx : c10::irange(num_leading_ones)) { - result[idx] = 0; - } - const auto t_sizes = t.sizes(); - const auto t_strides = t.strides(); - for (const auto idx : - c10::irange(num_leading_ones, num_leading_ones + t.dim())) { - result[idx] = t_sizes[idx - num_leading_ones] == 1 - ? 0 - : t_strides[idx - num_leading_ones]; - } - return result; - } - - // The 0th entry is the current linear index into the output, - // followed by kNumInputs input indexes. - std::array current_indexes_ = {0}; - using ShapeType = std:: - array; - ShapeType delinearized_output_index_ = {0}; - ssize_t output_dim_; - ArrayRef output_shape_; - // The linear index for a broadcast tensor is - // sum(delinearized_output_index_[i] * input_stride_[i] if - // padded_input_shape_[i] != 1 else 0), where padded_input_shape is - // input.sizes() with leading 1s added to make its size equal to - // output_dim. This is straightforwardly implementable with an - // adjusted stride array that contains 0s where the padded input - // shape would contain 1s. - std::array effective_input_broadcast_strides_ = {{0}}; -}; -} // namespace internal - -// Efficient mechanism for looping over the index space for an output -// tensor and kNumInputs possibly-broadcasted input tensors. Use as follows: -// -// auto* output_data = output.mutable_data_ptr(); -// const auto* a_data = a.mutable_data_ptr(); -// const auto* b_data = b.mutable_data_ptr(); -// for (const auto [output_index, a_index, b_index] : -// BroadcastIndexesRange<2>(output, a, b)) { -// // Access output_data[output_index], a_data[a_index], and b_data[b_index]. -// } -// -// (where OutputType, AType, and BType are known concrete types.) -// -// Unlike looping using delinearize_index() and -// linearize_access_indexes(), BroadcastIndexesRange avoids expensive -// division and modulo operations on each iteration. -template -class BroadcastIndexesRange { - public: - using iterator = internal::BroadcastIndexesIterator; - - template - BroadcastIndexesRange(const Tensor& output, const Args&... args) - : tensors_{&output, (&args)...} {} - - iterator begin() const { - return std::apply( - [](const auto&... 
args) { return iterator((*args)...); }, tensors_); - } - - iterator end() const { - return std::apply( - [](const auto&... args) { - return iterator(typename iterator::make_end_t(), (*args)...); - }, - tensors_); - } - - private: - std::array tensors_; -}; -} // namespace torch::executor diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h index f6bfae9bdaa..10bd07baee2 100644 --- a/kernels/portable/cpu/util/broadcast_util.h +++ b/kernels/portable/cpu/util/broadcast_util.h @@ -9,7 +9,6 @@ #pragma once #include -#include #include #include @@ -291,18 +290,23 @@ inline void apply_binary_elementwise_fn( const CTYPE_B* const data_b = b.const_data_ptr(); CTYPE_OUT* const data_out = out.mutable_data_ptr(); - if (any_is_broadcasted) { - for (const auto [out_index, a_index, b_index] : - BroadcastIndexesRange<2>(out, a, b)) { - data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]); - } - } else { - for (const auto i : c10::irange(out.numel())) { - size_t a_linear_index = i; - size_t b_linear_index = i; + for (const auto i : c10::irange(out.numel())) { + size_t a_linear_index = i; + size_t b_linear_index = i; + + if (any_is_broadcasted) { + size_t out_indexes[kTensorDimensionLimit]; + delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]); + if (a_is_broadcasted) { + a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + } + if (b_is_broadcasted) { + b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + } } + + data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]); } } @@ -334,16 +338,28 @@ inline void apply_ternary_elementwise_fn( const CTYPE_C* const data_c = c.const_data_ptr(); CTYPE_OUT* const data_out = out.mutable_data_ptr(); - if (any_is_broadcasted) { - for (const auto [out_index, a_index, b_index, c_index] : - BroadcastIndexesRange<3>(out, a, b, c)) { - data_out[out_index] = - compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]); - } - } else { - for (const auto i : c10::irange(out.numel())) { - data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]); + for (const auto i : c10::irange(out.numel())) { + size_t a_linear_index = i; + size_t b_linear_index = i; + size_t c_linear_index = i; + + if (any_is_broadcasted) { + size_t out_indexes[kTensorDimensionLimit]; + delinearize_index(i, out, out_indexes, kTensorDimensionLimit); + + if (a_is_broadcasted) { + a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + } + if (b_is_broadcasted) { + b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + } + if (c_is_broadcasted) { + c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c); + } } + + data_out[i] = compute_fun( + data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]); } } diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 09db5f7180d..778006f1b99 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -9,7 +9,6 @@ #pragma once #include -#include #include #include #include @@ -122,24 +121,26 @@ inline void apply_bitensor_elementwise_fn( char* const data_out = reinterpret_cast(out.mutable_data_ptr()); auto out_numel = out.numel(); - if (any_is_broadcasted) { - for (const auto [out_index, a_index, b_index] : - BroadcastIndexesRange<2>(out, a, b)) { - auto result = compute_fun( - 
load_a_to_common(&data_a[a_index * a_element_size]), - load_b_to_common(&data_b[b_index * b_element_size])); - store_common_to_out(result, &data_out[out_index * out_element_size]); - } - } else { - for (const auto i : c10::irange(out_numel)) { - size_t a_linear_index = i; - size_t b_linear_index = i; - - auto result = compute_fun( - load_a_to_common(&data_a[a_linear_index * a_element_size]), - load_b_to_common(&data_b[b_linear_index * b_element_size])); - store_common_to_out(result, &data_out[i * out_element_size]); + for (const auto i : c10::irange(out_numel)) { + size_t a_linear_index = i; + size_t b_linear_index = i; + + if (any_is_broadcasted) { + size_t out_indexes[kTensorDimensionLimit]; + delinearize_index(i, out, out_indexes, kTensorDimensionLimit); + + if (a_is_broadcasted) { + a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + } + if (b_is_broadcasted) { + b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + } } + + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); } } @@ -210,27 +211,31 @@ inline void apply_tritensor_elementwise_fn( char* const data_out = reinterpret_cast(out.mutable_data_ptr()); auto out_numel = out.numel(); - if (any_is_broadcasted) { - for (const auto [out_index, a_index, b_index, c_index] : - BroadcastIndexesRange<3>(out, a, b, c)) { - auto result = compute_fun( - load_a_to_common(&data_a[a_index * a_element_size]), - load_b_to_common(&data_b[b_index * b_element_size]), - load_c_to_common(&data_c[c_index * c_element_size])); - store_common_to_out(result, &data_out[out_index * out_element_size]); - } - } else { - for (const auto i : c10::irange(out_numel)) { - size_t a_linear_index = i; - size_t b_linear_index = i; - size_t c_linear_index = i; - - auto result = compute_fun( - load_a_to_common(&data_a[a_linear_index * a_element_size]), - load_b_to_common(&data_b[b_linear_index * b_element_size]), - load_c_to_common(&data_c[c_linear_index * c_element_size])); - store_common_to_out(result, &data_out[i * out_element_size]); + for (const auto i : c10::irange(out_numel)) { + size_t a_linear_index = i; + size_t b_linear_index = i; + size_t c_linear_index = i; + + if (any_is_broadcasted) { + size_t out_indexes[kTensorDimensionLimit]; + delinearize_index(i, out, out_indexes, kTensorDimensionLimit); + + if (a_is_broadcasted) { + a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); + } + if (b_is_broadcasted) { + b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); + } + if (c_is_broadcasted) { + c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c); + } } + + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size]), + load_c_to_common(&data_c[c_linear_index * c_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); } } diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 739bc117fbf..2b22687274f 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -70,9 +70,6 @@ def define_common_targets(): exported_headers = [ "broadcast_util.h", ], - exported_deps = [ - ":broadcast_indexes_range", - ], deps = [ ":repeat_util", "//executorch/runtime/kernel:kernel_includes", @@ -281,19 +278,6 @@ def define_common_targets(): 
visibility = ["//executorch/kernels/portable/cpu/..."], ) - runtime.cxx_library( - name = "broadcast_indexes_range", - exported_headers = ["broadcast_indexes_range.h"], - deps = [ - "//executorch/runtime/core/exec_aten:lib", - "//executorch/runtime/core/exec_aten/util:tensor_dimension_limit", - ], - visibility = [ - "//executorch/...", - "@EXECUTORCH_CLIENTS", - ], - ) - # Utility functions that can be used by operators that perform reduction for aten_mode in get_aten_mode_options(): suffix = "_aten" if aten_mode else "" diff --git a/kernels/portable/cpu/util/test/CMakeLists.txt b/kernels/portable/cpu/util/test/CMakeLists.txt index b92e8ebfae1..5f81e4b6aec 100644 --- a/kernels/portable/cpu/util/test/CMakeLists.txt +++ b/kernels/portable/cpu/util/test/CMakeLists.txt @@ -19,9 +19,7 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) include(${EXECUTORCH_ROOT}/build/Test.cmake) -set(_test_srcs broadcast_indexes_range_test.cpp broadcast_test.cpp - reduce_test.cpp -) +set(_test_srcs broadcast_test.cpp reduce_test.cpp) et_cxx_test( kernels_portable_cpu_util_test SOURCES ${_test_srcs} EXTRA_LIBS diff --git a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp deleted file mode 100644 index 9e17e676891..00000000000 --- a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#include - -using executorch::aten::ScalarType; -using executorch::aten::Tensor; -using executorch::runtime::testing::TensorFactory; -using torch::executor::BroadcastIndexesRange; -using torch::executor::delinearize_index; -using torch::executor::linearize_access_indexes; - -namespace { -template -auto range_to_vec(const Range& rng) { - return std::vector( - rng.begin(), rng.end()); -} -} // namespace -TEST(BroadcastIndexesRangeTest, Empty) { - TensorFactory tf; - - Tensor a = tf.make({0}, {}); - ASSERT_EQ(a.numel(), 0); - bool loop_entered = false; - for (auto _ : BroadcastIndexesRange<1>(a, a)) { - loop_entered = true; - } - EXPECT_FALSE(loop_entered); -} - -// [W] -> [W] -TEST(BroadcastIndexesRangeTest, OneDNotBroadcasted) { - TensorFactory tf; - - Tensor out = tf.zeros({5}); - int idx = 0; - for (const auto& elem : range_to_vec(BroadcastIndexesRange<1>(out, out))) { - EXPECT_EQ(elem[0], idx++); - EXPECT_EQ(elem[0], elem[1]); - } -} - -// [1] -> [W] -TEST(BroadcastIndexesRangeTest, ScalarBroadcastToOneD) { - TensorFactory tf; - - Tensor out = tf.zeros({5}); - Tensor in = tf.zeros({1}); - - auto actual = range_to_vec(BroadcastIndexesRange<1>(out, in)); - decltype(actual) expected = { - {0, 0}, - {1, 0}, - {2, 0}, - {3, 0}, - {4, 0}, - }; - EXPECT_EQ(expected, actual); -} - -// [1] -> [H, W] -// [1, W] -> [H, W] -// [H, 1] -> [H, W] -// [H, W] -> [H, W] -// Cover all these at the same time to also exercise multiple input tensors. 
-TEST(BroadcastIndexesRangeTest, TwoDExhaustive) { - TensorFactory tf; - Tensor out = tf.zeros({3, 4}); - Tensor in_0d_scalar = tf.zeros({}); - Tensor in_1d_scalar = tf.zeros({1}); - Tensor in_2d_scalar = tf.zeros({1, 1}); - - Tensor in_row = tf.zeros({4}); - Tensor in_col = tf.zeros({3, 1}); - - Tensor in_not_broadcast = tf.zeros({3, 4}); - - auto actual = range_to_vec(BroadcastIndexesRange<6>( - out, - in_0d_scalar, - in_1d_scalar, - in_2d_scalar, - in_row, - in_col, - in_not_broadcast)); - decltype(actual) expected = { - {0, 0, 0, 0, 0, 0, 0}, - {1, 0, 0, 0, 1, 0, 1}, - {2, 0, 0, 0, 2, 0, 2}, - {3, 0, 0, 0, 3, 0, 3}, - {4, 0, 0, 0, 0, 1, 4}, - {5, 0, 0, 0, 1, 1, 5}, - {6, 0, 0, 0, 2, 1, 6}, - {7, 0, 0, 0, 3, 1, 7}, - {8, 0, 0, 0, 0, 2, 8}, - {9, 0, 0, 0, 1, 2, 9}, - {10, 0, 0, 0, 2, 2, 10}, - {11, 0, 0, 0, 3, 2, 11}, - }; - EXPECT_EQ(expected, actual); -} - -// Here we assume that the previous tests established that padding -// with leading 1s is working, and test: -// [C, H, 1] -> [C, H, W] -// [C, 1, W] -> [C, H, W] -// [C, 1, 1] -> [C, H, W] -// [1, H, 1] -> [C, H, W] -// [C, H, W] -> [C, H, W] -TEST(BroadcastIndexesRangeTest, ThreeDBroadcasting) { - TensorFactory tf; - Tensor out = tf.zeros({2, 3, 4}); - Tensor in_broadcast_w = tf.zeros({2, 3, 1}); - Tensor in_broadcast_h = tf.zeros({2, 1, 4}); - Tensor in_broadcast_hw = tf.zeros({2, 1, 1}); - Tensor in_broadcast_cw = tf.zeros({1, 3, 1}); - Tensor in_not_broadcast = tf.zeros({2, 3, 4}); - auto actual = range_to_vec(BroadcastIndexesRange<5>( - out, - in_broadcast_w, - in_broadcast_h, - in_broadcast_hw, - in_broadcast_cw, - in_not_broadcast)); - decltype(actual) expected = { - {0, 0, 0, 0, 0, 0}, {1, 0, 1, 0, 0, 1}, {2, 0, 2, 0, 0, 2}, - {3, 0, 3, 0, 0, 3}, {4, 1, 0, 0, 1, 4}, {5, 1, 1, 0, 1, 5}, - {6, 1, 2, 0, 1, 6}, {7, 1, 3, 0, 1, 7}, {8, 2, 0, 0, 2, 8}, - {9, 2, 1, 0, 2, 9}, {10, 2, 2, 0, 2, 10}, {11, 2, 3, 0, 2, 11}, - {12, 3, 4, 1, 0, 12}, {13, 3, 5, 1, 0, 13}, {14, 3, 6, 1, 0, 14}, - {15, 3, 7, 1, 0, 15}, {16, 4, 4, 1, 1, 16}, {17, 4, 5, 1, 1, 17}, - {18, 4, 6, 1, 1, 18}, {19, 4, 7, 1, 1, 19}, {20, 5, 4, 1, 2, 20}, - {21, 5, 5, 1, 2, 21}, {22, 5, 6, 1, 2, 22}, {23, 5, 7, 1, 2, 23}, - }; - EXPECT_EQ(expected, actual); -} - -// 4-D should generalize, but we will go ahead and test: -// [N, 1, H, 1] -> [N, C, H, W] -// [1, C, 1, W] -> [N, C, H, W] -TEST(BroadcastIndexesRangeTest, FourDBroadcasting) { - TensorFactory tf; - Tensor out = tf.zeros({2, 3, 4, 5}); - Tensor in_broadcast_cw = tf.zeros({2, 1, 4, 1}); - Tensor in_broadcast_nh = tf.zeros({1, 3, 1, 5}); - - int idx = 0; - // Writing out all the indices would be too cumbersome, so here we - // take the opportunity to mutation test against delinearize_index - // and linearize_access_indexes. 
- for (const auto [out_idx, in_cw_idx, in_nh_idx] : - BroadcastIndexesRange<2>(out, in_broadcast_cw, in_broadcast_nh)) { - EXPECT_EQ(out_idx, idx++); - size_t out_indexes[executorch::runtime::kTensorDimensionLimit]; - delinearize_index( - out_idx, out, out_indexes, executorch::runtime::kTensorDimensionLimit); - EXPECT_EQ( - in_cw_idx, - linearize_access_indexes(out_indexes, out.dim(), in_broadcast_cw)); - EXPECT_EQ( - in_nh_idx, - linearize_access_indexes(out_indexes, out.dim(), in_broadcast_nh)); - } -} diff --git a/kernels/portable/cpu/util/test/targets.bzl b/kernels/portable/cpu/util/test/targets.bzl index 178eb25a79b..28988b90dcc 100644 --- a/kernels/portable/cpu/util/test/targets.bzl +++ b/kernels/portable/cpu/util/test/targets.bzl @@ -12,17 +12,6 @@ def define_common_targets(): ], ) - runtime.cxx_test( - name = "broadcast_indexes_range_test", - srcs = ["broadcast_indexes_range_test.cpp"], - deps = [ - "//executorch/kernels/portable/cpu/util:broadcast_util", - "//executorch/kernels/portable/cpu/util:broadcast_indexes_range", - "//executorch/runtime/core/exec_aten:lib", - "//executorch/runtime/core/exec_aten/testing_util:tensor_util", - ], - ) - runtime.cxx_test( name = "reduce_test", srcs = ["reduce_test.cpp"], diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index 394ec241698..24adb8d9c80 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -275,7 +275,6 @@ set(_optimized_kernels_test_sources "op_native_layer_norm_test.cpp" "op_neg_test.cpp" "op_sub_test.cpp" - "op_where_test.cpp" "UnaryUfuncRealHBBF16ToFloatHBF16Test.cpp" ${CMAKE_CURRENT_BINARY_DIR}/include/optimized/executorch/kernels/test/supported_features.cpp ) diff --git a/test/utils/OSSTestConfig.json b/test/utils/OSSTestConfig.json index cc5e625f1e8..70cb2d2e44f 100644 --- a/test/utils/OSSTestConfig.json +++ b/test/utils/OSSTestConfig.json @@ -7,8 +7,7 @@ "op_fast_hadamard_transform_test.cpp" ], "additional_libs": [ - "custom_ops", - "dumb_fht" + "custom_ops" ] }, { @@ -62,7 +61,6 @@ { "directory": "kernels/portable/cpu/util/test", "sources": [ - "broadcast_indexes_range_test.cpp", "broadcast_test.cpp", "reduce_test.cpp" ], From 764977b0f933490a6d68379ccf29ac7ab6b8abcc Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 28 Feb 2025 17:16:10 -0800 Subject: [PATCH 09/11] Update [ghstack-poisoned] --- kernels/portable/cpu/op_argmax.cpp | 5 ++++- kernels/portable/cpu/op_argmin.cpp | 12 +++++++++++- kernels/test/op_argmax_test.cpp | 13 +++++++++++++ kernels/test/op_argmin_test.cpp | 13 +++++++++++++ 4 files changed, 41 insertions(+), 2 deletions(-) diff --git a/kernels/portable/cpu/op_argmax.cpp b/kernels/portable/cpu/op_argmax.cpp index 39ad0171d5d..a272d4405a8 100644 --- a/kernels/portable/cpu/op_argmax.cpp +++ b/kernels/portable/cpu/op_argmax.cpp @@ -50,7 +50,10 @@ Tensor& argmax_out( for (const auto out_ix : c10::irange(out.numel())) { std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>( [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) { - if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) { + // the below condition as written is equivalent to + // !isnan(acc_val) && (isnan(v) || v > acc_val). See + // argument in op_argmin.cpp.
+ if (!std::isnan(acc_val) && !(v <= acc_val)) { acc_val = v; acc_ix = ix; } diff --git a/kernels/portable/cpu/op_argmin.cpp b/kernels/portable/cpu/op_argmin.cpp index 8148efa6264..a0ee82d2612 100644 --- a/kernels/portable/cpu/op_argmin.cpp +++ b/kernels/portable/cpu/op_argmin.cpp @@ -50,7 +50,17 @@ Tensor& argmin_out( for (const auto out_ix : c10::irange(out.numel())) { std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>( [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) { - if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) { + // the below condition as written is equivalent to !isnan(acc_val) && + // (isnan(v) || v < acc_val). cases: + // - if neither acc_val nor v is NaN, !(v >= acc_val) is + // trivially equivalent to v < acc_val. + // - if acc_val is NaN, the whole thing is trivially false. + // - if acc_val is not NaN and v is NaN, then v >= acc_val + // is false because all comparisons involving NaN are + // false, so the result is true. The result is trivially + // true for the above condition that uses isnan(v) as + // well. + if (!std::isnan(acc_val) && !(v >= acc_val)) { acc_val = v; acc_ix = ix; } diff --git a/kernels/test/op_argmax_test.cpp b/kernels/test/op_argmax_test.cpp index 66c79cefff7..4d68dfe88be 100644 --- a/kernels/test/op_argmax_test.cpp +++ b/kernels/test/op_argmax_test.cpp @@ -90,3 +90,16 @@ TEST_F(OpArgmaxTest, SanityCheckNullDim) { EXPECT_TENSOR_EQ(out, expected); // clang-format on } + +TEST_F(OpArgmaxTest, FirstNaNWins) { + TensorFactory<ScalarType::Float> tf_float; + Tensor in = tf_float.make({4}, {1, NAN, -4, NAN}); + + TensorFactory<ScalarType::Long> tf_long; + Tensor out = tf_long.zeros({}); + Tensor expected = tf_long.make({}, {1}); + + Tensor ret = op_argmax_out(in, {}, false, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); +} diff --git a/kernels/test/op_argmin_test.cpp b/kernels/test/op_argmin_test.cpp index 250fe4f7e1e..a0b2699a28f 100644 --- a/kernels/test/op_argmin_test.cpp +++ b/kernels/test/op_argmin_test.cpp @@ -90,3 +90,16 @@ TEST_F(OpArgminTest, SanityCheckNullDim) { EXPECT_TENSOR_EQ(out, expected); // clang-format on } + +TEST_F(OpArgminTest, FirstNaNWins) { + TensorFactory<ScalarType::Float> tf_float; + Tensor in = tf_float.make({4}, {1, NAN, -4, NAN}); + + TensorFactory<ScalarType::Long> tf_long; + Tensor out = tf_long.zeros({}); + Tensor expected = tf_long.make({}, {1}); + + Tensor ret = op_argmin_out(in, {}, false, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); +} From 80589b0187221732b852c4e06f354ad5230adc6a Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 4 Mar 2025 09:47:48 -0800 Subject: [PATCH 10/11] Update [ghstack-poisoned] --- kernels/portable/cpu/op_argmax.cpp | 5 +---- kernels/portable/cpu/op_argmin.cpp | 12 +----------- kernels/test/op_argmax_test.cpp | 13 ------------- kernels/test/op_argmin_test.cpp | 13 ------------- 4 files changed, 2 insertions(+), 41 deletions(-) diff --git a/kernels/portable/cpu/op_argmax.cpp b/kernels/portable/cpu/op_argmax.cpp index a272d4405a8..39ad0171d5d 100644 --- a/kernels/portable/cpu/op_argmax.cpp +++ b/kernels/portable/cpu/op_argmax.cpp @@ -50,10 +50,7 @@ Tensor& argmax_out( for (const auto out_ix : c10::irange(out.numel())) { std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>( [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) { - // the below condition as written is equivalent to - // !isnan(acc_val) && (isnan(v) || v > acc_val). See - // argument in op_argmin.cpp.
- if (!std::isnan(acc_val) && !(v <= acc_val)) { + if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) { acc_val = v; acc_ix = ix; } diff --git a/kernels/portable/cpu/op_argmin.cpp b/kernels/portable/cpu/op_argmin.cpp index a0ee82d2612..8148efa6264 100644 --- a/kernels/portable/cpu/op_argmin.cpp +++ b/kernels/portable/cpu/op_argmin.cpp @@ -50,17 +50,7 @@ Tensor& argmin_out( for (const auto out_ix : c10::irange(out.numel())) { std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>( [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) { - // the below condition as written is equivalent to !isnan(acc_val) && - // (isnan(v) || v < acc_val). cases: - // - if neither acc_val nor v is NaN, !(v >= acc_val) is - // trivially equivalent to v < acc_val. - // - if acc_val is NaN, the whole thing is trivially false. - // - if acc_val is not NaN and v is NaN, then v >= acc_val - // is false because all comparisons involving NaN are - // false, so the result is true. The result is trivially - // true for the above condition that uses isnan(v) as - // well. - if (!std::isnan(acc_val) && !(v >= acc_val)) { + if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) { acc_val = v; acc_ix = ix; } diff --git a/kernels/test/op_argmax_test.cpp b/kernels/test/op_argmax_test.cpp index 4d68dfe88be..66c79cefff7 100644 --- a/kernels/test/op_argmax_test.cpp +++ b/kernels/test/op_argmax_test.cpp @@ -90,16 +90,3 @@ TEST_F(OpArgmaxTest, SanityCheckNullDim) { EXPECT_TENSOR_EQ(out, expected); // clang-format on } - -TEST_F(OpArgmaxTest, FirstNaNWins) { - TensorFactory<ScalarType::Float> tf_float; - Tensor in = tf_float.make({4}, {1, NAN, -4, NAN}); - - TensorFactory<ScalarType::Long> tf_long; - Tensor out = tf_long.zeros({}); - Tensor expected = tf_long.make({}, {1}); - - Tensor ret = op_argmax_out(in, {}, false, out); - EXPECT_TENSOR_EQ(out, ret); - EXPECT_TENSOR_EQ(out, expected); -} diff --git a/kernels/test/op_argmin_test.cpp b/kernels/test/op_argmin_test.cpp index a0b2699a28f..250fe4f7e1e 100644 --- a/kernels/test/op_argmin_test.cpp +++ b/kernels/test/op_argmin_test.cpp @@ -90,16 +90,3 @@ TEST_F(OpArgminTest, SanityCheckNullDim) { EXPECT_TENSOR_EQ(out, expected); // clang-format on } - -TEST_F(OpArgminTest, FirstNaNWins) { - TensorFactory<ScalarType::Float> tf_float; - Tensor in = tf_float.make({4}, {1, NAN, -4, NAN}); - - TensorFactory<ScalarType::Long> tf_long; - Tensor out = tf_long.zeros({}); - Tensor expected = tf_long.make({}, {1}); - - Tensor ret = op_argmin_out(in, {}, false, out); - EXPECT_TENSOR_EQ(out, ret); - EXPECT_TENSOR_EQ(out, expected); -} From 18c64ad24755dc29e57c5d2942c9d51259ca5b01 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 4 Mar 2025 10:00:15 -0800 Subject: [PATCH 11/11] Update [ghstack-poisoned] --- backends/xnnpack/CMakeLists.txt | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index 56b8ba96a05..02cd0b6d988 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -33,14 +33,14 @@ if(NOT PYTHON_EXECUTABLE) resolve_python_executable() endif() -# NB: Enabling this will serialize execution of delegate instances -# Keeping this OFF by default to maintain existing behavior, to be revisited. +# NB: Enabling this will serialize execution of delegate instances. Keeping this +# OFF by default to maintain existing behavior, to be revisited.
option(EXECUTORCH_XNNPACK_SHARED_WORKSPACE - "Enable workspace sharing across different delegate instances" ON) -# Keeping this OFF by default due to regressions in decode -# and model load with kleidi kernels -option(EXECUTORCH_XNNPACK_ENABLE_KLEIDI - "Enable Arm Kleidi kernels" OFF) + "Enable workspace sharing across different delegate instances" ON +) +# Keeping this OFF by default due to regressions in decode and model load with +# kleidi kernels +option(EXECUTORCH_XNNPACK_ENABLE_KLEIDI "Enable Arm Kleidi kernels" OFF) if(EXECUTORCH_XNNPACK_SHARED_WORKSPACE) add_definitions(-DENABLE_XNNPACK_SHARED_WORKSPACE) endif() @@ -100,8 +100,7 @@ include(cmake/Dependencies.cmake) list(TRANSFORM _xnnpack_backend__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(xnnpack_backend STATIC ${_xnnpack_backend__srcs}) target_link_libraries( - xnnpack_backend PRIVATE ${xnnpack_third_party} executorch_core - xnnpack_schema + xnnpack_backend PRIVATE ${xnnpack_third_party} executorch_core xnnpack_schema ) target_include_directories( @@ -119,6 +118,12 @@ target_include_directories( target_compile_options(xnnpack_backend PUBLIC ${_common_compile_options}) target_link_options_shared_lib(xnnpack_backend) +if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED) + list(APPEND xnn_executor_runner_libs optimized_native_cpu_ops_lib) +else() + list(APPEND xnn_executor_runner_libs portable_ops_lib) +endif() + list(APPEND xnn_executor_runner_libs xnnpack_backend executorch) # ios can only build library but not binary @@ -134,13 +139,14 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$") if(EXECUTORCH_BUILD_DEVTOOLS) list(APPEND xnn_executor_runner_libs etdump) else() - message(SEND_ERROR "Use of 'EXECUTORCH_ENABLE_EVENT_TRACER' requires 'EXECUTORCH_BUILD_DEVTOOLS' to be enabled.") + message( + SEND_ERROR + "Use of 'EXECUTORCH_ENABLE_EVENT_TRACER' requires 'EXECUTORCH_BUILD_DEVTOOLS' to be enabled." + ) endif() endif() - target_link_libraries( - xnn_executor_runner gflags optimized_native_cpu_ops_lib ${xnn_executor_runner_libs} - ) + target_link_libraries(xnn_executor_runner gflags ${xnn_executor_runner_libs}) target_compile_options(xnn_executor_runner PUBLIC ${_common_compile_options}) endif()
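As a closing illustration, the two broadcast-indexing styles this stack moves between, written against the utilities shown in the diffs above. This is a sketch, not a definitive implementation: it assumes the intermediate state of the stack in which both broadcast_indexes_range.h and the broadcast_util.h helpers exist, float tensors, and an `out` already resized to the broadcast target shape; the copy functions themselves are hypothetical.

#include <cstddef>

#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>

using executorch::runtime::kTensorDimensionLimit;
using torch::executor::BroadcastIndexesRange;
using torch::executor::delinearize_index;
using torch::executor::linearize_access_indexes;
using torch::executor::Tensor;

// Old style: for every output element, delinearize the output index (one
// division and modulus per dimension), then re-linearize it against the
// broadcasted input.
void copy_broadcast_old_style(const Tensor& a, Tensor& out) {
  const float* a_data = a.const_data_ptr<float>();
  float* out_data = out.mutable_data_ptr<float>();
  for (std::size_t i = 0; i < static_cast<std::size_t>(out.numel()); ++i) {
    std::size_t out_indexes[kTensorDimensionLimit];
    delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
    out_data[i] = a_data[linearize_access_indexes(out_indexes, out.dim(), a)];
  }
}

// New style: the range maintains the output and input indexes
// incrementally, with no division or modulus in the loop body.
void copy_broadcast_range_style(const Tensor& a, Tensor& out) {
  const float* a_data = a.const_data_ptr<float>();
  float* out_data = out.mutable_data_ptr<float>();
  for (const auto [out_index, a_index] : BroadcastIndexesRange<1>(out, a)) {
    out_data[out_index] = a_data[a_index];
  }
}

The second loop is the style introduced by patches 02-03 and backed out in patch 08; broadcast_indexes_range_test.cpp cross-checks exactly these two index computations against each other.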