Use compile-time promotion to reduce optimized mul op size & build time #3532


Closed
wants to merge 8 commits
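
The change below replaces the innermost runtime dtype switch, over the promoted "common" type, with a type computed at compile time from the two operand C types via `torch::executor::promote_types`. As a standalone analogy, such a trait maps each pair of input types to the type the arithmetic runs in. Here is a minimal sketch; it is an illustration only, not the ExecuTorch implementation, which as the `op_mul.cpp` hunk below shows also takes a `half_to_float` flag:

```cpp
#include <cstdint>
#include <type_traits>

// Minimal analogy for a compile-time promotion trait. Same-type pairs promote
// to themselves; mixed pairs are enumerated to mirror ATen's promotion table.
template <typename A, typename B>
struct promote_types_sketch;  // primary template left undefined

template <typename T>
struct promote_types_sketch<T, T> {
  using type = T;
};

template <>
struct promote_types_sketch<std::int32_t, float> {
  using type = float;
};
template <>
struct promote_types_sketch<float, std::int32_t> {
  using type = float;
};

static_assert(
    std::is_same<promote_types_sketch<float, std::int32_t>::type,
                 float>::value,
    "int32 * float runs in float -- no runtime switch needed");
```

Because `CTYPE_IN` becomes a pure function of `CTYPE_A` and `CTYPE_B`, the compiler instantiates one inner lambda per (A, B, OUT) combination instead of one per (A, B, IN, OUT), which is where the size and build-time savings come from.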
73 changes: 59 additions & 14 deletions kernels/optimized/cpu/op_mul.cpp
@@ -41,6 +41,50 @@ bool can_use_optimized_path(
(a.numel() == b.numel() && a.numel() == out.numel()));
return can_use_optimized_path;
}

template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MulInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MulInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = a_casted * b_casted;

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, Tensor&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};
} // namespace

Tensor& opt_mul_out(
@@ -86,20 +130,21 @@ Tensor& opt_mul_out(

ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
ET_SWITCH_REALB_TYPES(common_type, ctx, "mul.out", CTYPE_IN, [&]() {
ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = a_casted * b_casted;

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = a_casted * b_casted;

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
});
});
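
The `MulInner`/`ReportCanCastBug` pair above uses a boolean template parameter (`can_cast`) to route impossible type combinations to a single shared stub rather than instantiating a full kernel for them. A self-contained illustration of that guard pattern follows; the names and the `std::is_convertible` stand-in are assumptions for the sketch, not ExecuTorch code:

```cpp
#include <cstdio>
#include <type_traits>

// Invalid combinations collapse into one shared diagnostic stub.
struct ReportBug {
  static void run() { std::fprintf(stderr, "BUG: unreachable combination\n"); }
};

template <bool valid, typename In, typename Out>
struct Kernel : ReportBug {};  // primary template: the invalid case

template <typename In, typename Out>
struct Kernel<true, In, Out> {
  static void run() {
    // A real element-wise loop would go here; only valid combinations
    // instantiate it, so only they contribute code size.
  }
};

template <typename In, typename Out>
void dispatch() {
  // std::is_convertible stands in for ExecuTorch's can_cast check.
  Kernel<std::is_convertible<In, Out>::value, In, Out>::run();
}

int main() {
  dispatch<float, double>();  // valid: runs the real kernel
  return 0;
}
```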
63 changes: 18 additions & 45 deletions kernels/portable/cpu/op_bitwise_and.cpp
@@ -6,8 +6,10 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cmath>
// patternlint-disable-next-line executorch-cpp-nostdinc
#include <functional>

#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,20 +19,6 @@ namespace torch {
namespace executor {
namespace native {

namespace {

template <typename CTYPE>
CTYPE bitwise_and(CTYPE a, CTYPE b) {
return a & b;
}

template <>
bool bitwise_and<bool>(bool a, bool b) {
return a && b;
}

} // namespace

using Tensor = exec_aten::Tensor;

Tensor& bitwise_and_Tensor_out(
@@ -55,38 +43,23 @@ Tensor& bitwise_and_Tensor_out(
Bool, a_type, ctx, "bitwise_and.Tensor_out", CTYPE_A, [&]() {
ET_SWITCH_INT_TYPES_AND(
Bool, b_type, ctx, "bitwise_and.Tensor_out", CTYPE_B, [&]() {
ET_SWITCH_INT_TYPES_AND(
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES_AND(
Bool,
common_type,
out_type,
ctx,
"bitwise_and.Tensor_out",
CTYPE_IN,
CTYPE_OUT,
[&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool,
out_type,
ctx,
"bitwise_and.Tensor_out",
CTYPE_OUT,
[&]() {
apply_binary_elementwise_fn<
CTYPE_A,
CTYPE_B,
CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted =
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value =
bitwise_and(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
internal::BitwiseOpInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
std::bit_and,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});
@@ -142,8 +115,8 @@ Tensor& bitwise_and_Scalar_out(
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value =
bitwise_and(a_casted, b_casted);
CTYPE_IN value = std::bit_and<CTYPE_IN>()(
a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
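
All three bitwise ports now delegate to `internal::BitwiseOpInner` from `pattern/bitwise_op.h`, a header that is not part of this excerpt. Judging from the call sites, it follows the same shape as `MulInner` in `op_mul.cpp`, parameterized additionally on a functor template such as `std::bit_and`. A hypothetical reconstruction, inferred from the call sites rather than copied from the actual header:

```cpp
// Hypothetical sketch of pattern/bitwise_op.h. OP is a functor template
// (std::bit_and, std::bit_or, std::bit_xor) selected per operator.
namespace torch {
namespace executor {
namespace native {
namespace internal {

template <
    bool can_cast,
    template <typename> class OP,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct BitwiseOpInner;

template <
    template <typename> class OP,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct BitwiseOpInner<true, OP, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
        [](const CTYPE_A val_a, const CTYPE_B val_b) {
          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
          return static_cast<CTYPE_OUT>(OP<CTYPE_IN>()(a_casted, b_casted));
        },
        a,
        b,
        out);
  }
};

// A <false, ...> partial specialization would mirror ReportCanCastBug from
// op_mul.cpp, turning "should never happen" combinations into a shared stub.

} // namespace internal
} // namespace native
} // namespace executor
} // namespace torch
```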
61 changes: 18 additions & 43 deletions kernels/portable/cpu/op_bitwise_or.cpp
@@ -6,8 +6,10 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cmath>
// patternlint-disable-next-line executorch-cpp-nostdinc
#include <functional>

#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,20 +19,6 @@ namespace torch {
namespace executor {
namespace native {

namespace {

template <typename CTYPE>
CTYPE bitwise_or(CTYPE a, CTYPE b) {
return a | b;
}

template <>
bool bitwise_or<bool>(bool a, bool b) {
return a || b;
}

} // namespace

using Tensor = exec_aten::Tensor;

Tensor& bitwise_or_Tensor_out(
@@ -55,37 +43,23 @@ Tensor& bitwise_or_Tensor_out(
Bool, a_type, ctx, "bitwise_or.Tensor_out", CTYPE_A, [&]() {
ET_SWITCH_INT_TYPES_AND(
Bool, b_type, ctx, "bitwise_or.Tensor_out", CTYPE_B, [&]() {
ET_SWITCH_INT_TYPES_AND(
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES_AND(
Bool,
common_type,
out_type,
ctx,
"bitwise_or.Tensor_out",
CTYPE_IN,
CTYPE_OUT,
[&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool,
out_type,
ctx,
"bitwise_or.Tensor_out",
CTYPE_OUT,
[&]() {
apply_binary_elementwise_fn<
CTYPE_A,
CTYPE_B,
CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted =
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = bitwise_or(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
internal::BitwiseOpInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
std::bit_or,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});
@@ -141,7 +115,8 @@ Tensor& bitwise_or_Scalar_out(
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = bitwise_or(a_casted, b_casted);
CTYPE_IN value =
std::bit_or<CTYPE_IN>()(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
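
Dropping the handwritten bool specializations (`bitwise_and<bool>` as `&&`, `bitwise_or<bool>` as `||`, `bitwise_xor<bool>` as `!=`) is behavior-preserving: for operands restricted to `false`/`true`, the built-in bitwise operators applied by the `std::` functors agree with those logical operators. A self-contained check:

```cpp
#include <functional>

// For bool, & coincides with &&, | with ||, and ^ with != -- exactly what the
// deleted specializations computed. The functors are constexpr since C++14.
static_assert(std::bit_and<bool>()(true, false) == (true && false), "and");
static_assert(std::bit_or<bool>()(true, false) == (true || false), "or");
static_assert(std::bit_xor<bool>()(true, true) == (true != true), "xor");

int main() {
  return 0;  // the static_asserts above are the whole test
}
```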
64 changes: 18 additions & 46 deletions kernels/portable/cpu/op_bitwise_xor.cpp
@@ -6,8 +6,10 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cmath>
// patternlint-disable-next-line executorch-cpp-nostdinc
#include <functional>

#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,28 +19,13 @@ namespace torch {
namespace executor {
namespace native {

namespace {

template <typename CTYPE>
CTYPE bitwise_xor(CTYPE a, CTYPE b) {
return a ^ b;
}

template <>
bool bitwise_xor<bool>(bool a, bool b) {
return a != b;
}

} // namespace

using Tensor = exec_aten::Tensor;

Tensor& bitwise_xor_Tensor_out(
RuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out) {
// Determine output size and resize for dynamic shapes
ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
@@ -56,38 +43,23 @@ Tensor& bitwise_xor_Tensor_out(
Bool, a_type, ctx, "bitwise_xor.Tensor_out", CTYPE_A, [&]() {
ET_SWITCH_INT_TYPES_AND(
Bool, b_type, ctx, "bitwise_xor.Tensor_out", CTYPE_B, [&]() {
ET_SWITCH_INT_TYPES_AND(
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES_AND(
Bool,
common_type,
out_type,
ctx,
"bitwise_xor.Tensor_out",
CTYPE_IN,
CTYPE_OUT,
[&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool,
out_type,
ctx,
"bitwise_xor.Tensor_out",
CTYPE_OUT,
[&]() {
apply_binary_elementwise_fn<
CTYPE_A,
CTYPE_B,
CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted =
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value =
bitwise_xor(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
internal::BitwiseOpInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
std::bit_xor,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});
@@ -143,8 +115,8 @@ Tensor& bitwise_xor_Scalar_out(
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value =
bitwise_xor(a_casted, b_casted);
CTYPE_IN value = std::bit_xor<CTYPE_IN>()(
a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
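
A rough sense of why removing one switch level matters for `op_mul.cpp`: the inner lambda used to be stamped out once per (A, B, IN, OUT) type quadruple, and is now stamped out once per (A, B, OUT) triple. Assuming `ET_SWITCH_REALHB_TYPES` covers nine scalar types and `ET_SWITCH_REALB_TYPES` eight, an assumption about the macro definitions, which are not shown in this diff:

```cpp
// Back-of-the-envelope count of inner-lambda instantiations in opt_mul_out.
constexpr int kRealHB = 9; // assumed: Byte..Double, plus Half and Bool
constexpr int kRealB = 8;  // assumed: Byte..Double, plus Bool
constexpr int kBefore = kRealHB * kRealHB * kRealB * kRealHB; // 5832
constexpr int kAfter = kRealHB * kRealHB * kRealHB;           // 729
static_assert(kBefore == 8 * kAfter,
              "one fewer switch level, roughly 8x fewer lambdas");
```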