diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h index 7b78f4c2814..aaf7207d0c9 100644 --- a/kernels/portable/cpu/util/broadcast_indexes_range.h +++ b/kernels/portable/cpu/util/broadcast_indexes_range.h @@ -122,6 +122,11 @@ class BroadcastIndexesIterator { } output_index() += n; + if (output_dim_or_zero_if_no_broadcasting_ == 0) { + std::fill( + current_indexes_.begin() + 1, current_indexes_.end(), output_index()); + return *this; + } delinearize_index( output_index(), output_shape_, diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 23ec481bb7f..f5932069005 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -94,17 +95,28 @@ inline void apply_elementwise_fn( char* const data_out = reinterpret_cast(out.mutable_data_ptr()); const auto out_element_size = out.element_size(); - for (const auto& indexes : - BroadcastIndexesRange(out, (*inputs.first)...)) { - std::array loaded_inputs; - for (const auto idx : c10::irange(kNumInputs)) { - const auto& input_info = inputs_info[idx]; - loaded_inputs[idx] = input_info.load_to_common( - &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]); - } - auto result = std::apply(compute_fun, loaded_inputs); - store_common_to_out(result, &data_out[indexes[0] * out_element_size]); - } + ::executorch::extension::parallel_for( + 0, + out.numel(), + ::executorch::extension::internal::GRAIN_SIZE, + [&](const auto begin, const auto end) { + const auto range = + BroadcastIndexesRange(out, (*inputs.first)...); + auto begin_it = range.begin(); + begin_it += begin; + for (; (*begin_it)[0] < end; ++begin_it) { + const auto& indexes = *begin_it; + std::array loaded_inputs; + for (const auto idx : c10::irange(kNumInputs)) { + const auto& input_info = inputs_info[idx]; + loaded_inputs[idx] = input_info.load_to_common( + &input_info + .data_ptr[indexes[idx + 1] * input_info.element_size]); + } + auto result = std::apply(compute_fun, loaded_inputs); + store_common_to_out(result, &data_out[indexes[0] * out_element_size]); + } + }); } } // namespace internal diff --git a/kernels/portable/cpu/util/functional_util.h b/kernels/portable/cpu/util/functional_util.h index 609a1a26fa5..d7ea201dbd2 100644 --- a/kernels/portable/cpu/util/functional_util.h +++ b/kernels/portable/cpu/util/functional_util.h @@ -12,6 +12,7 @@ #include #include +#include namespace torch { namespace executor { @@ -53,9 +54,15 @@ inline void apply_unary_map_fn( CTYPE_OUT* const data_out, const int64_t size, const int64_t stride = 1) { - for (const auto i : c10::irange(size)) { - data_out[i * stride] = map_fun(data_in[i * stride]); - } + executorch::extension::parallel_for( + 0, + size, + ::executorch::extension::internal::GRAIN_SIZE, + [&](const auto begin, const auto end) { + for (const auto i : c10::irange(begin, end)) { + data_out[i * stride] = map_fun(data_in[i * stride]); + } + }); } // diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 95fd1734d8e..a623b9d4d7a 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -111,6 +111,7 @@ def define_common_targets(): ":broadcast_util", ":dtype_util", "//executorch/runtime/kernel:kernel_runtime_context", + "//executorch/runtime/kernel:thread_parallel_interface", ], deps = [ "//executorch/kernels/portable/cpu:scalar_utils", @@ -243,6 +244,9 @@ def define_common_targets(): name = "functional_util", srcs = [], exported_headers = ["functional_util.h"], + exported_deps = [ + "//executorch/runtime/kernel:thread_parallel_interface", + ], deps = [ "//executorch/runtime/kernel:kernel_includes", "//executorch/runtime/core/exec_aten/util:tensor_util", diff --git a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp index 519cd9fe9f9..1023915ea66 100644 --- a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp +++ b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp @@ -44,7 +44,9 @@ TEST(BroadcastIndexesRangeTest, OneDNotBroadcasted) { Tensor out = tf.zeros({5}); int idx = 0; - for (const auto& elem : range_to_vec(BroadcastIndexesRange<1>(out, out))) { + const auto range = BroadcastIndexesRange<1>(out, out); + for (const auto& elem : range_to_vec(range)) { + EXPECT_EQ(*(range.begin() + idx), elem); EXPECT_EQ(elem[0], idx++); EXPECT_EQ(elem[0], elem[1]); } @@ -71,7 +73,7 @@ TEST(BroadcastIndexesRangeTest, ScalarBroadcastToOneD) { template void test_operator_plus(const Range& range) { size_t idx = 0; - for (const auto indexes : range) { + for (const auto& indexes : range) { EXPECT_EQ(*(range.begin() + idx), indexes); idx++; }