Skip to content

Commit ffdc424

Browse files
swolchokfacebook-github-bot
authored andcommitted
Use compile-time promotion to reduce floor_divide size & build time (pytorch#3455)
Summary: Continuing rollout of this technique. Reviewed By: manuelcandales Differential Revision: D56827786
1 parent ae89a37 commit ffdc424

File tree

2 files changed

+70
-29
lines changed

2 files changed

+70
-29
lines changed

kernels/portable/cpu/op_floor_divide.cpp

+64-29
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,60 @@ namespace native {
2020
using Tensor = exec_aten::Tensor;
2121
using ScalarType = exec_aten::ScalarType;
2222

23+
namespace {
24+
template <
25+
bool can_cast,
26+
typename CTYPE_A,
27+
typename CTYPE_B,
28+
typename CTYPE_IN,
29+
typename CTYPE_OUT>
30+
struct FloorDivideInner;
31+
32+
template <
33+
typename CTYPE_A,
34+
typename CTYPE_B,
35+
typename CTYPE_IN,
36+
typename CTYPE_OUT>
37+
struct FloorDivideInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
38+
static void
39+
run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
40+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
41+
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
42+
[&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
43+
if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
44+
if (val_b == 0) {
45+
div_by_zero_error = true;
46+
return static_cast<CTYPE_OUT>(0);
47+
}
48+
}
49+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
50+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
51+
CTYPE_IN value = utils::floor_divide<CTYPE_IN>(a_casted, b_casted);
52+
53+
return static_cast<CTYPE_OUT>(value);
54+
},
55+
a,
56+
b,
57+
out);
58+
}
59+
};
60+
61+
struct ReportCanCastBug {
62+
static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
63+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
64+
}
65+
};
66+
67+
template <
68+
typename CTYPE_A,
69+
typename CTYPE_B,
70+
typename CTYPE_IN,
71+
typename CTYPE_OUT>
72+
struct FloorDivideInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
73+
: public ReportCanCastBug {};
74+
75+
} // namespace
76+
2377
Tensor& floor_divide_out(
2478
RuntimeContext& ctx,
2579
const Tensor& a,
@@ -46,36 +100,17 @@ Tensor& floor_divide_out(
46100
Bool, a_type, ctx, "floor_divide.out", CTYPE_A, [&]() {
47101
ET_SWITCH_REAL_TYPES_AND(
48102
Bool, b_type, ctx, "floor_divide.out", CTYPE_B, [&]() {
103+
using CTYPE_IN = typename torch::executor::
104+
promote_types<CTYPE_A, CTYPE_B>::type;
105+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
49106
ET_SWITCH_REAL_TYPES(
50-
common_type, ctx, "floor_divide.out", CTYPE_IN, [&]() {
51-
ET_SWITCH_REAL_TYPES(
52-
out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
53-
apply_binary_elementwise_fn<
54-
CTYPE_A,
55-
CTYPE_B,
56-
CTYPE_OUT>(
57-
[common_type, &div_by_zero_error](
58-
const CTYPE_A val_a, const CTYPE_B val_b) {
59-
if (isIntegralType(
60-
common_type, /*includeBool=*/true)) {
61-
if (val_b == 0) {
62-
div_by_zero_error = true;
63-
return static_cast<CTYPE_OUT>(0);
64-
}
65-
}
66-
CTYPE_IN a_casted =
67-
static_cast<CTYPE_IN>(val_a);
68-
CTYPE_IN b_casted =
69-
static_cast<CTYPE_IN>(val_b);
70-
CTYPE_IN value = utils::floor_divide<CTYPE_IN>(
71-
a_casted, b_casted);
72-
73-
return static_cast<CTYPE_OUT>(value);
74-
},
75-
a,
76-
b,
77-
out);
78-
});
107+
out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
108+
FloorDivideInner<
109+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
110+
CTYPE_A,
111+
CTYPE_B,
112+
CTYPE_IN,
113+
CTYPE_OUT>::run(a, b, out, div_by_zero_error);
79114
});
80115
});
81116
});

runtime/core/exec_aten/util/scalar_type_util.h

+6
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,12 @@ inline constexpr bool isIntegralType(
349349
t == exec_aten::ScalarType::Short);
350350
}
351351

352+
template <typename T, bool includeBool>
353+
struct is_integral_type
354+
: public std::integral_constant<
355+
bool,
356+
isIntegralType(CppTypeToScalarType<T>::value, includeBool)> {};
357+
352358
inline constexpr bool isFloatingType(exec_aten::ScalarType t) {
353359
return (
354360
t == exec_aten::ScalarType::Double || t == exec_aten::ScalarType::Float ||

0 commit comments

Comments
 (0)