Skip to content

Commit 6ce8087

Browse files
swolchok authored and
facebook-github-bot committed
Use compile-time promotion to reduce optimized le op size & build time (#3534)
Summary: Yet another optimized op. Differential Revision: D57028967
1 parent 59be8f9 commit 6ce8087

File tree

2 files changed

+20
-27
lines changed

2 files changed

+20
-27
lines changed

kernels/optimized/cpu/op_le.cpp

+15 -20
Original file line number | Diff line number | Diff line change
@@ -53,31 +53,26 @@ Tensor& opt_le_tensor_out(
5353
a.numel());
5454
});
5555
} else {
56-
ScalarType common_type = promoteTypes(a_type, b_type);
5756
ET_SWITCH_REAL_TYPES_AND(
5857
Bool, a_type, ctx, "le.Tensor_out", CTYPE_A, [&]() {
5958
ET_SWITCH_REAL_TYPES_AND(
6059
Bool, b_type, ctx, "le.Tensor_out", CTYPE_B, [&]() {
60+
using CTYPE_IN = typename torch::executor::
61+
promote_types<CTYPE_A, CTYPE_B>::type;
62+
ET_DCHECK(
63+
CppTypeToScalarType<CTYPE_IN>::value ==
64+
promoteTypes(a_type, b_type));
6165
ET_SWITCH_REAL_TYPES_AND(
62-
Bool, common_type, ctx, "le.Tensor_out", CTYPE_IN, [&]() {
63-
ET_SWITCH_REAL_TYPES_AND(
64-
Bool,
65-
out_type,
66-
ctx,
67-
"le.Tensor_out",
68-
CTYPE_OUT,
69-
[&]() {
70-
const size_t n = a.numel();
71-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
72-
const CTYPE_B* b_data = b.const_data_ptr<CTYPE_B>();
73-
CTYPE_OUT* out_data =
74-
out.mutable_data_ptr<CTYPE_OUT>();
75-
for (auto i = 0; i < n; ++i) {
76-
out_data[i] = static_cast<CTYPE_OUT>(
77-
static_cast<CTYPE_IN>(a_data[i]) <=
78-
static_cast<CTYPE_IN>(b_data[i]));
79-
}
80-
});
66+
Bool, out_type, ctx, "le.Tensor_out", CTYPE_OUT, [&]() {
67+
const size_t n = a.numel();
68+
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
69+
const CTYPE_B* b_data = b.const_data_ptr<CTYPE_B>();
70+
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
71+
for (auto i = 0; i < n; ++i) {
72+
out_data[i] = static_cast<CTYPE_OUT>(
73+
static_cast<CTYPE_IN>(a_data[i]) <=
74+
static_cast<CTYPE_IN>(b_data[i]));
75+
}
8176
});
8277
});
8378
});

kernels/optimized/cpu/op_sub.cpp

+5 -7
Original file line number | Diff line number | Diff line change
@@ -124,19 +124,17 @@ Tensor& opt_sub_out(
124124
using CTYPE_IN = typename torch::executor::
125125
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
126126
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
127-
ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
128-
CTYPE_IN alpha_val;
129-
ET_KERNEL_CHECK(
130-
ctx,
131-
utils::extract_scalar(alpha, &alpha_val),
132-
InvalidArgument, );
127+
ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
128+
CTYPE_IN alpha_val;
129+
ET_KERNEL_CHECK(
130+
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
133131
SubInner<
134132
can_cast<CTYPE_IN, CTYPE_OUT>::value,
135133
CTYPE_A,
136134
CTYPE_B,
137135
CTYPE_IN,
138136
CTYPE_OUT>::run(a, b, alpha_val, out);
139-
});
137+
});
140138
});
141139
});
142140
}

0 commit comments

Comments
 (0)