Skip to content

Commit f5927d3

Browse files
swolchok authored and facebook-github-bot committed
Use compile-time promotion to reduce optimized add/sub op size & build time
Summary: Yet another pair of ops. Differential Revision: D57023819
1 parent 032c173 commit f5927d3

File tree

2 files changed

+121
-34
lines changed

2 files changed

+121
-34
lines changed

kernels/optimized/cpu/op_add.cpp

+63-20
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,55 @@
1616
namespace torch {
1717
namespace executor {
1818
namespace native {
19+
namespace {
20+
21+
template <
22+
bool can_cast,
23+
typename CTYPE_A,
24+
typename CTYPE_B,
25+
typename CTYPE_IN,
26+
typename CTYPE_OUT>
27+
struct AddInner;
28+
29+
template <
30+
typename CTYPE_A,
31+
typename CTYPE_B,
32+
typename CTYPE_IN,
33+
typename CTYPE_OUT>
34+
struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
35+
static void
36+
run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
37+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
38+
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
39+
[alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
40+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
41+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
42+
CTYPE_IN value = a_casted + alpha_val * b_casted;
43+
44+
return static_cast<CTYPE_OUT>(value);
45+
},
46+
a,
47+
b,
48+
out);
49+
}
50+
};
51+
52+
template <typename CTYPE_IN>
53+
struct ReportCanCastBug {
54+
static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
55+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
56+
}
57+
};
58+
59+
template <
60+
typename CTYPE_A,
61+
typename CTYPE_B,
62+
typename CTYPE_IN,
63+
typename CTYPE_OUT>
64+
struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
65+
: public ReportCanCastBug<CTYPE_IN> {};
66+
67+
} // namespace
1968

2069
using Tensor = exec_aten::Tensor;
2170
using ScalarType = exec_aten::ScalarType;
@@ -69,26 +118,20 @@ Tensor& opt_add_out(
69118

70119
ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
71120
ET_SWITCH_REALHB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
72-
ET_SWITCH_REALB_TYPES(common_type, ctx, "add.out", CTYPE_IN, [&]() {
73-
ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
74-
CTYPE_IN alpha_val;
75-
ET_KERNEL_CHECK(
76-
ctx,
77-
utils::extract_scalar(alpha, &alpha_val),
78-
InvalidArgument, );
79-
80-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
81-
[alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
82-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
83-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
84-
CTYPE_IN value = a_casted + alpha_val * b_casted;
85-
86-
return static_cast<CTYPE_OUT>(value);
87-
},
88-
a,
89-
b,
90-
out);
91-
});
121+
using CTYPE_IN = typename torch::executor::
122+
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
123+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
124+
ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
125+
CTYPE_IN alpha_val;
126+
ET_KERNEL_CHECK(
127+
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
128+
129+
AddInner<
130+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
131+
CTYPE_A,
132+
CTYPE_B,
133+
CTYPE_IN,
134+
CTYPE_OUT>::run(a, b, alpha_val, out);
92135
});
93136
});
94137
});

kernels/optimized/cpu/op_sub.cpp

+58-14
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,55 @@
1717
namespace torch {
1818
namespace executor {
1919
namespace native {
20+
namespace {
21+
22+
template <
23+
bool can_cast,
24+
typename CTYPE_A,
25+
typename CTYPE_B,
26+
typename CTYPE_IN,
27+
typename CTYPE_OUT>
28+
struct SubInner;
29+
30+
template <
31+
typename CTYPE_A,
32+
typename CTYPE_B,
33+
typename CTYPE_IN,
34+
typename CTYPE_OUT>
35+
struct SubInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
36+
static void
37+
run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
38+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39+
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
40+
[alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
41+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
42+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
43+
CTYPE_IN value = a_casted - alpha_val * b_casted;
44+
45+
return static_cast<CTYPE_OUT>(value);
46+
},
47+
a,
48+
b,
49+
out);
50+
}
51+
};
52+
53+
template <typename CTYPE_IN>
54+
struct ReportCanCastBug {
55+
static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
56+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
57+
}
58+
};
59+
60+
template <
61+
typename CTYPE_A,
62+
typename CTYPE_B,
63+
typename CTYPE_IN,
64+
typename CTYPE_OUT>
65+
struct SubInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
66+
: public ReportCanCastBug<CTYPE_IN> {};
67+
68+
} // namespace
2069

2170
using Tensor = exec_aten::Tensor;
2271
using ScalarType = exec_aten::ScalarType;
@@ -72,27 +121,22 @@ Tensor& opt_sub_out(
72121

73122
ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() {
74123
ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() {
75-
ET_SWITCH_REAL_TYPES(common_type, ctx, "sub.out", CTYPE_IN, [&]() {
124+
using CTYPE_IN = typename torch::executor::
125+
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
126+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
76127
ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
77128
CTYPE_IN alpha_val;
78129
ET_KERNEL_CHECK(
79130
ctx,
80131
utils::extract_scalar(alpha, &alpha_val),
81132
InvalidArgument, );
82-
83-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
84-
[alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
85-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
86-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
87-
CTYPE_IN value = a_casted - alpha_val * b_casted;
88-
89-
return static_cast<CTYPE_OUT>(value);
90-
},
91-
a,
92-
b,
93-
out);
133+
SubInner<
134+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
135+
CTYPE_A,
136+
CTYPE_B,
137+
CTYPE_IN,
138+
CTYPE_OUT>::run(a, b, alpha_val, out);
94139
});
95-
});
96140
});
97141
});
98142
}

0 commit comments

Comments
 (0)