Skip to content

Commit 41a85d2

Browse files
swolchokfacebook-github-bot
authored andcommitted
Use compile-time promotion to reduce max/min size & build time (#3459)
Summary: Pull Request resolved: #3459 Yet another smaller pair of ops. Differential Revision: D56807402
1 parent 8e35e24 commit 41a85d2

File tree

2 files changed

+108
-29
lines changed

2 files changed

+108
-29
lines changed

kernels/portable/cpu/op_maximum.cpp

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,50 @@ const T& max(const T& a, const T& b) {
2020
return (b > a) ? b : a;
2121
}
2222

23+
template <
24+
bool can_cast,
25+
typename CTYPE_A,
26+
typename CTYPE_B,
27+
typename CTYPE_IN,
28+
typename CTYPE_OUT>
29+
struct MaximumInner;
30+
31+
template <
32+
typename CTYPE_A,
33+
typename CTYPE_B,
34+
typename CTYPE_IN,
35+
typename CTYPE_OUT>
36+
struct MaximumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37+
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
38+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39+
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
40+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
41+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
42+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
43+
CTYPE_IN value = max(a_casted, b_casted);
44+
45+
return static_cast<CTYPE_OUT>(value);
46+
},
47+
a,
48+
b,
49+
out);
50+
}
51+
};
52+
53+
struct ReportCanCastBug {
54+
static void run(const Tensor&, const Tensor&, Tensor&) {
55+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
56+
}
57+
};
58+
59+
template <
60+
typename CTYPE_A,
61+
typename CTYPE_B,
62+
typename CTYPE_IN,
63+
typename CTYPE_OUT>
64+
struct MaximumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
65+
: public ReportCanCastBug {};
66+
2367
} // namespace
2468

2569
Tensor& maximum_out(
@@ -44,20 +88,16 @@ Tensor& maximum_out(
4488

4589
ET_SWITCH_REALHB_TYPES(a_type, ctx, "maximum.out", CTYPE_A, [&]() {
4690
ET_SWITCH_REALHB_TYPES(b_type, ctx, "maximum.out", CTYPE_B, [&]() {
47-
ET_SWITCH_REALB_TYPES(common_type, ctx, "maximum.out", CTYPE_IN, [&]() {
48-
ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
49-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
50-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
51-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
52-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
53-
CTYPE_IN value = max(a_casted, b_casted);
54-
55-
return static_cast<CTYPE_OUT>(value);
56-
},
57-
a,
58-
b,
59-
out);
60-
});
91+
using CTYPE_IN = typename torch::executor::
92+
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
93+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
94+
ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
95+
MaximumInner<
96+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
97+
CTYPE_A,
98+
CTYPE_B,
99+
CTYPE_IN,
100+
CTYPE_OUT>::run(a, b, out);
61101
});
62102
});
63103
});

kernels/portable/cpu/op_minimum.cpp

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,50 @@ const T& min(const T& a, const T& b) {
2020
return (b < a) ? b : a;
2121
}
2222

23+
template <
24+
bool can_cast,
25+
typename CTYPE_A,
26+
typename CTYPE_B,
27+
typename CTYPE_IN,
28+
typename CTYPE_OUT>
29+
struct MinimumInner;
30+
31+
template <
32+
typename CTYPE_A,
33+
typename CTYPE_B,
34+
typename CTYPE_IN,
35+
typename CTYPE_OUT>
36+
struct MinimumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37+
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
38+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39+
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
40+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
41+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
42+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
43+
CTYPE_IN value = min(a_casted, b_casted);
44+
45+
return static_cast<CTYPE_OUT>(value);
46+
},
47+
a,
48+
b,
49+
out);
50+
}
51+
};
52+
53+
struct ReportCanCastBug {
54+
static void run(const Tensor&, const Tensor&, Tensor&) {
55+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
56+
}
57+
};
58+
59+
template <
60+
typename CTYPE_A,
61+
typename CTYPE_B,
62+
typename CTYPE_IN,
63+
typename CTYPE_OUT>
64+
struct MinimumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
65+
: public ReportCanCastBug {};
66+
2367
} // namespace
2468

2569
Tensor& minimum_out(
@@ -44,22 +88,17 @@ Tensor& minimum_out(
4488

4589
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "minimum.out", CTYPE_A, [&]() {
4690
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "minimum.out", CTYPE_B, [&]() {
91+
using CTYPE_IN =
92+
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
93+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
4794
ET_SWITCH_REAL_TYPES_AND(
48-
Bool, common_type, ctx, "minimum.out", CTYPE_IN, [&]() {
49-
ET_SWITCH_REAL_TYPES_AND(
50-
Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
51-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
52-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
53-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
54-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
55-
CTYPE_IN value = min(a_casted, b_casted);
56-
57-
return static_cast<CTYPE_OUT>(value);
58-
},
59-
a,
60-
b,
61-
out);
62-
});
95+
Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
96+
MinimumInner<
97+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
98+
CTYPE_A,
99+
CTYPE_B,
100+
CTYPE_IN,
101+
CTYPE_OUT>::run(a, b, out);
63102
});
64103
});
65104
});

0 commit comments

Comments
 (0)