Skip to content

Commit 5bed826

Browse files
swolchok authored and facebook-github-bot committed
Use compile-time promotion to reduce remainder size & build time (#3458)
Summary: Pull Request resolved: #3458. Yet another op that can benefit from compile-time type promotion. Differential Revision: D56831293.
1 parent 3e5ed1a commit 5bed826

File tree

2 files changed

+70
-25
lines changed

2 files changed

+70
-25
lines changed

kernels/portable/cpu/op_remainder.cpp

+56-25
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,52 @@ namespace native {
2020

2121
using Tensor = exec_aten::Tensor;
2222

23+
namespace {
24+
template <
25+
bool can_cast,
26+
typename CTYPE_A,
27+
typename CTYPE_B,
28+
typename CTYPE_IN,
29+
typename CTYPE_OUT>
30+
struct RemainderInner;
31+
32+
template <
33+
typename CTYPE_A,
34+
typename CTYPE_B,
35+
typename CTYPE_IN,
36+
typename CTYPE_OUT>
37+
struct RemainderInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
38+
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
39+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
40+
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
41+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
42+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
43+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
44+
CTYPE_IN value = utils::remainder_override(a_casted, b_casted);
45+
46+
return static_cast<CTYPE_OUT>(value);
47+
},
48+
a,
49+
b,
50+
out);
51+
}
52+
};
53+
54+
struct ReportCanCastBug {
55+
static void run(const Tensor&, const Tensor&, Tensor&) {
56+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
57+
}
58+
};
59+
60+
template <
61+
typename CTYPE_A,
62+
typename CTYPE_B,
63+
typename CTYPE_IN,
64+
typename CTYPE_OUT>
65+
struct RemainderInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
66+
: public ReportCanCastBug {};
67+
68+
} // namespace
2369
Tensor& remainder_Tensor_out(
2470
RuntimeContext& ctx,
2571
const Tensor& a,
@@ -45,32 +91,17 @@ Tensor& remainder_Tensor_out(
4591
Bool, a_type, ctx, "remainder.Tensor_out", CTYPE_A, [&]() {
4692
ET_SWITCH_REAL_TYPES_AND(
4793
Bool, b_type, ctx, "remainder.Tensor_out", CTYPE_B, [&]() {
94+
using CTYPE_IN = typename torch::executor::
95+
promote_types<CTYPE_A, CTYPE_B>::type;
96+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
4897
ET_SWITCH_REAL_TYPES(
49-
common_type, ctx, "remainder.Tensor_out", CTYPE_IN, [&]() {
50-
ET_SWITCH_REAL_TYPES(
51-
out_type,
52-
ctx,
53-
"remainder.Tensor_out",
54-
CTYPE_OUT,
55-
[&]() {
56-
apply_binary_elementwise_fn<
57-
CTYPE_A,
58-
CTYPE_B,
59-
CTYPE_OUT>(
60-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
61-
CTYPE_IN a_casted =
62-
static_cast<CTYPE_IN>(val_a);
63-
CTYPE_IN b_casted =
64-
static_cast<CTYPE_IN>(val_b);
65-
CTYPE_IN value = utils::remainder_override(
66-
a_casted, b_casted);
67-
68-
return static_cast<CTYPE_OUT>(value);
69-
},
70-
a,
71-
b,
72-
out);
73-
});
98+
out_type, ctx, "remainder.Tensor_out", CTYPE_OUT, [&]() {
99+
RemainderInner<
100+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
101+
CTYPE_A,
102+
CTYPE_B,
103+
CTYPE_IN,
104+
CTYPE_OUT>::run(a, b, out);
74105
});
75106
});
76107
});

kernels/test/op_remainder_test.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using exec_aten::Tensor;
2121
using torch::executor::testing::TensorFactory;
2222

2323
class OpRemainderOutTest : public OperatorTest {
24+
protected:
2425
Tensor& op_remainder_tensor_out(
2526
const Tensor& self,
2627
const Tensor& other,
@@ -35,3 +36,16 @@ class OpRemainderOutTest : public OperatorTest {
3536
return torch::executor::aten::remainder_outf(context_, self, other, out);
3637
}
3738
};
39+
40+
// Smoke test with mixed dtypes: Long `self`, Int `other`, Double `out`.
// 46 remainder 4 is 2, so every element of `out` should be 2.0.
TEST_F(OpRemainderOutTest, SmokeTest) {
  // The output factory must really be Double: with a Long factory the test
  // would never exercise the integer-input -> floating-output path.
  TensorFactory<ScalarType::Double> tfDouble;
  TensorFactory<ScalarType::Long> tfLong;
  TensorFactory<ScalarType::Int> tfInt;

  Tensor self = tfLong.full({2, 2}, 46);
  Tensor other = tfInt.full({2, 2}, 4);
  Tensor out = tfDouble.zeros({2, 2});
  Tensor out_expected = tfDouble.full({2, 2}, 2.0);
  op_remainder_tensor_out(self, other, out);
  EXPECT_TENSOR_CLOSE(out, out_expected);
}

0 commit comments

Comments
 (0)