diff --git a/kernels/portable/cpu/op_argmin.cpp b/kernels/portable/cpu/op_argmin.cpp
index 87e90de4c04..b0816596e4e 100644
--- a/kernels/portable/cpu/op_argmin.cpp
+++ b/kernels/portable/cpu/op_argmin.cpp
@@ -12,7 +12,6 @@
 #include
 #include
-#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include
 
 namespace torch {
@@ -48,17 +47,8 @@ Tensor& argmin_out(
   ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
 
-    // REVIEW: this is the parallelization strategy ATen uses
-    // specifically when the reduction is along the last dimension and
-    // that dimension is contiguous. Is there any particular reason we
-    // shouldn't just always use this strategy since we aren't
-    // otherwise capable of parallelizing reductions?
-    const int64_t reduction_size = get_reduced_dim_product(in, dim);
-    const auto grain_size = std::max(
-        static_cast<int64_t>(1),
-        executorch::extension::internal::GRAIN_SIZE / reduction_size);
-    const bool success = executorch::extension::parallel_for(
-        0, out.numel(), grain_size, [&](const auto begin, const auto end) {
+    const bool success = parallel_for_each_reduce_over_dim_output_index(
+        in, dim, out, [&](const auto begin, const auto end) {
           for (const auto out_ix : c10::irange(begin, end)) {
             std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
                 [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
diff --git a/kernels/portable/cpu/util/reduce_util.h b/kernels/portable/cpu/util/reduce_util.h
index 2160d9810ae..f9eeb6bebbc 100644
--- a/kernels/portable/cpu/util/reduce_util.h
+++ b/kernels/portable/cpu/util/reduce_util.h
@@ -10,6 +10,7 @@
 #include
 #include
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include
 #include
 
@@ -811,5 +812,23 @@ bool check_prod_out_args(
 #endif
 
+/**
+ * parallel_for wrapper for reductions that call reduce_over_dim or
+ * map_reduce_over_dim for each output element. Automatically
+ * calculates an appropriate grain size.
+ */
+template <typename Func>
+[[nodiscard]] bool parallel_for_each_reduce_over_dim_output_index(
+    const Tensor& in,
+    executorch::aten::optional<int64_t> dim,
+    const Tensor& out,
+    const Func& func) {
+  const int64_t reduction_size = get_reduced_dim_product(in, dim);
+  const auto grain_size = std::max(
+      static_cast<int64_t>(1),
+      executorch::extension::internal::GRAIN_SIZE / reduction_size);
+  return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
+}
+
 } // namespace executor
 } // namespace torch
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index db8202a920a..95fd1734d8e 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -314,6 +314,9 @@ def define_common_targets():
                 "//executorch/runtime/kernel:kernel_includes{}".format(suffix),
                 "//executorch/runtime/core/exec_aten/util:tensor_util{}".format(suffix),
             ],
+            exported_deps = [
+                "//executorch/runtime/kernel:thread_parallel_interface",
+            ],
             exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
             visibility = [
                 "//executorch/extension/llm/custom_ops/...",
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index dd48da64c30..b56413b92f4 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -284,7 +284,6 @@ ATEN_OPS = (
         name = "op_argmin",
         deps = [
             "//executorch/kernels/portable/cpu/util:reduce_util",
-            "//executorch/runtime/kernel:thread_parallel_interface",
         ],
     ),
     op_target(
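
For reference, the grain-size heuristic that the new parallel_for_each_reduce_over_dim_output_index helper centralizes can be shown with a small standalone sketch. The names kGrainSize and grain_size_for below are illustrative placeholders rather than names from this patch, and 32768 is only an assumed stand-in for the real executorch::extension::internal::GRAIN_SIZE constant, whose value is not shown here.

#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <iostream>

// Assumed placeholder for executorch::extension::internal::GRAIN_SIZE.
constexpr int64_t kGrainSize = 32768;

// Mirrors the helper's grain-size computation: each parallel task should
// cover enough output indices that its total reduced work is on the order
// of kGrainSize elements, but never fewer than one output index per task.
int64_t grain_size_for(int64_t reduction_size) {
  return std::max(static_cast<int64_t>(1), kGrainSize / reduction_size);
}

int main() {
  // A long reduced dimension yields a small grain (each output index is
  // already a lot of work); a short one yields a large grain (batch many
  // output indices per task to amortize scheduling overhead).
  for (int64_t reduction_size : {4, 1024, 65536}) {
    std::cout << "reduction_size=" << reduction_size
              << " -> grain_size=" << grain_size_for(reduction_size) << "\n";
  }
  return 0;
}

Dividing the grain-size constant by the length of the reduced dimension keeps the work per parallel task roughly constant regardless of the reduction's shape, which is what lets the helper be reused by any kernel that performs one reduce_over_dim call per output element.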