diff --git a/kernels/portable/cpu/op_any.cpp b/kernels/portable/cpu/op_any.cpp
index a9dd79ad34d..a368226db80 100644
--- a/kernels/portable/cpu/op_any.cpp
+++ b/kernels/portable/cpu/op_any.cpp
@@ -144,22 +144,26 @@ Tensor& any_out(
   ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
     ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
       CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-      for (const auto out_ix : c10::irange(out.numel())) {
-        CTYPE_OUT any = false;
-        if (in.numel() > 0) {
-          std::tuple<CTYPE_OUT, long> acc =
-              map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
-                  [](CTYPE_IN v) { return static_cast<bool>(v); },
-                  [](bool outv, long, bool acc, long) {
-                    return std::tuple<bool, long>{acc || outv, 0};
-                  },
-                  in,
-                  dim,
-                  out_ix);
-          any = std::get<0>(acc);
-        }
-        out_data[out_ix] = any;
-      }
+      const bool success = parallel_for_each_reduce_over_dim_output_index(
+          in, dim, out, [&](const auto begin, const auto end) {
+            for (const auto out_ix : c10::irange(begin, end)) {
+              CTYPE_OUT any = false;
+              if (in.numel() > 0) {
+                std::tuple<CTYPE_OUT, long> acc =
+                    map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
+                        [](CTYPE_IN v) { return static_cast<bool>(v); },
+                        [](bool outv, long, bool acc, long) {
+                          return std::tuple<bool, long>{acc || outv, 0};
+                        },
+                        in,
+                        dim,
+                        out_ix);
+                any = std::get<0>(acc);
+              }
+              out_data[out_ix] = any;
+            }
+          });
+      ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
     });
   });
diff --git a/kernels/portable/cpu/op_argmax.cpp b/kernels/portable/cpu/op_argmax.cpp
index a272d4405a8..ffbc469c53d 100644
--- a/kernels/portable/cpu/op_argmax.cpp
+++ b/kernels/portable/cpu/op_argmax.cpp
@@ -47,23 +47,27 @@ Tensor& argmax_out(
   ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmax.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
-    for (const auto out_ix : c10::irange(out.numel())) {
-      std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
-          [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
-            // the below condition as written is equivalent to
-            // !isnan(accval) && (isnan(v) || v > acc_val). See
-            // argument in op_argmin.cpp.
-            if (!std::isnan(acc_val) && !(v <= acc_val)) {
-              acc_val = v;
-              acc_ix = ix;
-            }
-            return std::tuple<CTYPE, long>{acc_val, acc_ix};
-          },
-          in,
-          dim,
-          out_ix);
-      out_data[out_ix] = std::get<1>(acc);
-    }
+    const bool success = parallel_for_each_reduce_over_dim_output_index(
+        in, dim, out, [&](const auto begin, const auto end) {
+          for (const auto out_ix : c10::irange(begin, end)) {
+            std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
+                [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
+                  // the below condition as written is equivalent to
+                  // !isnan(accval) && (isnan(v) || v > acc_val). See
+                  // argument in op_argmin.cpp.
+                  if (!std::isnan(acc_val) && !(v <= acc_val)) {
+                    acc_val = v;
+                    acc_ix = ix;
+                  }
+                  return std::tuple<CTYPE, long>{acc_val, acc_ix};
+                },
+                in,
+                dim,
+                out_ix);
+            out_data[out_ix] = std::get<1>(acc);
+          }
+        });
+    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
   });
 
   return out;
diff --git a/kernels/portable/cpu/op_max.cpp b/kernels/portable/cpu/op_max.cpp
index f206ee05b99..3f4a1d27c0e 100644
--- a/kernels/portable/cpu/op_max.cpp
+++ b/kernels/portable/cpu/op_max.cpp
@@ -83,21 +83,26 @@ std::tuple<Tensor&, Tensor&> max_out(
     CTYPE* max_data = max.mutable_data_ptr<CTYPE>();
     long* max_indices_data = max_indices.mutable_data_ptr<long>();
-    for (const auto out_ix : c10::irange(max.numel())) {
-      std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
-          [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
-            if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) {
-              acc_val = v;
-              acc_ix = ix;
-            }
-            return std::tuple<CTYPE, long>{acc_val, acc_ix};
-          },
-          in,
-          dim,
-          out_ix);
-      max_data[out_ix] = std::get<0>(acc);
-      max_indices_data[out_ix] = std::get<1>(acc);
-    }
+    const bool success = parallel_for_each_reduce_over_dim_output_index(
+        in, dim, max, [&](const auto begin, const auto end) {
+          for (const auto out_ix : c10::irange(begin, end)) {
+            std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
+                [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
+                  if (!std::isnan(acc_val) &&
+                      (std::isnan(v) || v > acc_val)) {
+                    acc_val = v;
+                    acc_ix = ix;
+                  }
+                  return std::tuple<CTYPE, long>{acc_val, acc_ix};
+                },
+                in,
+                dim,
+                out_ix);
+            max_data[out_ix] = std::get<0>(acc);
+            max_indices_data[out_ix] = std::get<1>(acc);
+          }
+        });
+    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
   });
 
   return {max, max_indices};
diff --git a/kernels/portable/cpu/op_min.cpp b/kernels/portable/cpu/op_min.cpp
index 683ef751a9d..8b70bcd40f5 100644
--- a/kernels/portable/cpu/op_min.cpp
+++ b/kernels/portable/cpu/op_min.cpp
@@ -83,21 +83,26 @@ std::tuple<Tensor&, Tensor&> min_out(
     CTYPE* min_data = min.mutable_data_ptr<CTYPE>();
     long* min_indices_data = min_indices.mutable_data_ptr<long>();
-    for (const auto out_ix : c10::irange(min.numel())) {
-      std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
-          [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
-            if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) {
-              acc_val = v;
-              acc_ix = ix;
-            }
-            return std::tuple<CTYPE, long>{acc_val, acc_ix};
-          },
-          in,
-          dim,
-          out_ix);
-      min_data[out_ix] = std::get<0>(acc);
-      min_indices_data[out_ix] = std::get<1>(acc);
-    }
+    const bool success = parallel_for_each_reduce_over_dim_output_index(
+        in, dim, min, [&](const auto begin, const auto end) {
+          for (const auto out_ix : c10::irange(begin, end)) {
+            std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
+                [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
+                  if (!std::isnan(acc_val) &&
+                      (std::isnan(v) || v < acc_val)) {
+                    acc_val = v;
+                    acc_ix = ix;
+                  }
+                  return std::tuple<CTYPE, long>{acc_val, acc_ix};
+                },
+                in,
+                dim,
+                out_ix);
+            min_data[out_ix] = std::get<0>(acc);
+            min_indices_data[out_ix] = std::get<1>(acc);
+          }
+        });
+    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
   });
 
   return {min, min_indices};
diff --git a/kernels/portable/cpu/op_prod.cpp b/kernels/portable/cpu/op_prod.cpp
index 27d18ca2570..54580459d7c 100644
--- a/kernels/portable/cpu/op_prod.cpp
+++ b/kernels/portable/cpu/op_prod.cpp
@@ -77,22 +77,26 @@ Tensor& prod_int_out(
   ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
     ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
       CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-      for (const auto out_ix : c10::irange(out.numel())) {
-        CTYPE_OUT prod = 1;
-        if (in.numel() > 0) {
-          std::tuple<CTYPE_OUT, long> acc =
-              map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
-                  [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-                  [](CTYPE_OUT outv, long, CTYPE_OUT acc, long) {
-                    return std::tuple<CTYPE_OUT, long>{acc * outv, 0};
-                  },
-                  in,
-                  dim,
-                  out_ix);
-          prod = std::get<0>(acc);
-        }
-        out_data[out_ix] = prod;
-      }
+      const bool success = parallel_for_each_reduce_over_dim_output_index(
+          in, dim, out, [&](const auto begin, const auto end) {
+            for (const auto out_ix : c10::irange(begin, end)) {
+              CTYPE_OUT prod = 1;
+              if (in.numel() > 0) {
+                std::tuple<CTYPE_OUT, long> acc =
+                    map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
+                        [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
+                        [](CTYPE_OUT outv, long, CTYPE_OUT acc, long) {
+                          return std::tuple<CTYPE_OUT, long>{acc * outv, 0};
+                        },
+                        in,
+                        dim,
+                        out_ix);
+                prod = std::get<0>(acc);
+              }
+              out_data[out_ix] = prod;
+            }
+          });
+      ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
     });
   });
diff --git a/kernels/portable/cpu/util/reduce_util.h b/kernels/portable/cpu/util/reduce_util.h
index f9eeb6bebbc..1c6a6de4101 100644
--- a/kernels/portable/cpu/util/reduce_util.h
+++ b/kernels/portable/cpu/util/reduce_util.h
@@ -823,10 +823,30 @@ template <typename Func>
     executorch::aten::optional<int64_t> dim,
     const Tensor& out,
     const Func& func) {
-  const int64_t reduction_size = get_reduced_dim_product(in, dim);
-  const auto grain_size = std::max(
-      static_cast<int64_t>(1),
-      executorch::extension::internal::GRAIN_SIZE / reduction_size);
+  const ssize_t reduction_size = get_reduced_dim_product(in, dim);
+  const auto grain_size = std::max(
+      static_cast<ssize_t>(1),
+      static_cast<ssize_t>(executorch::extension::internal::GRAIN_SIZE) /
+          reduction_size);
+  return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
+}
+
+/**
+ * parallel_for wrapper for reductions that call reduce_over_dim_list or
+ * map_reduce_over_dim_list for each output element. Automatically
+ * calculates appropriate grain size.
+ */
+template <typename Func>
+[[nodiscard]] bool parallel_for_each_reduce_over_dim_list_output_index(
+    const Tensor& in,
+    optional<ArrayRef<int64_t>> dim_list,
+    const Tensor& out,
+    const Func& func) {
+  const ssize_t reduction_size = get_reduced_dim_product(in, dim_list);
+  const auto grain_size = std::max(
+      static_cast<ssize_t>(1),
+      static_cast<ssize_t>(executorch::extension::internal::GRAIN_SIZE) /
+          reduction_size);
   return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
 }