Skip to content

Commit 032c173

Browse files
swolchokfacebook-github-bot
authored andcommitted
Use compile-time promotion to reduce optimized mul op size & build time
Summary: another in a long line of fixes. Differential Revision: D56896048
1 parent 446ed37 commit 032c173

File tree

1 file changed

+59
-14
lines changed

1 file changed

+59
-14
lines changed

kernels/optimized/cpu/op_mul.cpp

+59-14
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,50 @@ bool can_use_optimized_path(
4141
(a.numel() == b.numel() && a.numel() == out.numel()));
4242
return can_use_optimized_path;
4343
}
44+
45+
template <
46+
bool can_cast,
47+
typename CTYPE_A,
48+
typename CTYPE_B,
49+
typename CTYPE_IN,
50+
typename CTYPE_OUT>
51+
struct MulInner;
52+
53+
template <
54+
typename CTYPE_A,
55+
typename CTYPE_B,
56+
typename CTYPE_IN,
57+
typename CTYPE_OUT>
58+
struct MulInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
59+
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
60+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
61+
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
62+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
63+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
64+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
65+
CTYPE_IN value = a_casted * b_casted;
66+
67+
return static_cast<CTYPE_OUT>(value);
68+
},
69+
a,
70+
b,
71+
out);
72+
}
73+
};
74+
75+
struct ReportCanCastBug {
76+
static void run(const Tensor&, const Tensor&, Tensor&) {
77+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
78+
}
79+
};
80+
81+
template <
82+
typename CTYPE_A,
83+
typename CTYPE_B,
84+
typename CTYPE_IN,
85+
typename CTYPE_OUT>
86+
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
87+
: public ReportCanCastBug {};
4488
} // namespace
4589

4690
Tensor& opt_mul_out(
@@ -86,20 +130,21 @@ Tensor& opt_mul_out(
86130

87131
ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
88132
ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
89-
ET_SWITCH_REALB_TYPES(common_type, ctx, "mul.out", CTYPE_IN, [&]() {
90-
ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
91-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
92-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
93-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
94-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
95-
CTYPE_IN value = a_casted * b_casted;
96-
97-
return static_cast<CTYPE_OUT>(value);
98-
},
99-
a,
100-
b,
101-
out);
102-
});
133+
using CTYPE_IN = typename torch::executor::
134+
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
135+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
136+
ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
137+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
138+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
139+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
140+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
141+
CTYPE_IN value = a_casted * b_casted;
142+
143+
return static_cast<CTYPE_OUT>(value);
144+
},
145+
a,
146+
b,
147+
out);
103148
});
104149
});
105150
});

0 commit comments

Comments
 (0)