@@ -20,6 +20,60 @@ namespace native {
20
20
using Tensor = exec_aten::Tensor;
21
21
using ScalarType = exec_aten::ScalarType;
22
22
23
+ namespace {
24
+ template <
25
+ bool can_cast,
26
+ typename CTYPE_A,
27
+ typename CTYPE_B,
28
+ typename CTYPE_IN,
29
+ typename CTYPE_OUT>
30
+ struct FloorDivideInner ;
31
+
32
+ template <
33
+ typename CTYPE_A,
34
+ typename CTYPE_B,
35
+ typename CTYPE_IN,
36
+ typename CTYPE_OUT>
37
+ struct FloorDivideInner <true , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
38
+ static void
39
+ run (const Tensor& a, const Tensor& b, Tensor& out, bool & div_by_zero_error) {
40
+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
41
+ // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
42
+ [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
43
+ if (is_integral_type<CTYPE_IN, /* includeBool=*/ true >::value) {
44
+ if (val_b == 0 ) {
45
+ div_by_zero_error = true ;
46
+ return static_cast <CTYPE_OUT>(0 );
47
+ }
48
+ }
49
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
50
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
51
+ CTYPE_IN value = utils::floor_divide<CTYPE_IN>(a_casted, b_casted);
52
+
53
+ return static_cast <CTYPE_OUT>(value);
54
+ },
55
+ a,
56
+ b,
57
+ out);
58
+ }
59
+ };
60
+
61
+ struct ReportCanCastBug {
62
+ static void run (const Tensor&, const Tensor&, Tensor&, bool &) {
63
+ ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
64
+ }
65
+ };
66
+
67
+ template <
68
+ typename CTYPE_A,
69
+ typename CTYPE_B,
70
+ typename CTYPE_IN,
71
+ typename CTYPE_OUT>
72
+ struct FloorDivideInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
73
+ : public ReportCanCastBug {};
74
+
75
+ } // namespace
76
+
23
77
Tensor& floor_divide_out (
24
78
RuntimeContext& ctx,
25
79
const Tensor& a,
@@ -46,36 +100,17 @@ Tensor& floor_divide_out(
46
100
Bool, a_type, ctx, " floor_divide.out" , CTYPE_A, [&]() {
47
101
ET_SWITCH_REAL_TYPES_AND (
48
102
Bool, b_type, ctx, " floor_divide.out" , CTYPE_B, [&]() {
103
+ using CTYPE_IN = typename torch::executor::
104
+ promote_types<CTYPE_A, CTYPE_B>::type;
105
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
49
106
ET_SWITCH_REAL_TYPES (
50
- common_type, ctx, " floor_divide.out" , CTYPE_IN, [&]() {
51
- ET_SWITCH_REAL_TYPES (
52
- out_type, ctx, " floor_divide.out" , CTYPE_OUT, [&]() {
53
- apply_binary_elementwise_fn<
54
- CTYPE_A,
55
- CTYPE_B,
56
- CTYPE_OUT>(
57
- [common_type, &div_by_zero_error](
58
- const CTYPE_A val_a, const CTYPE_B val_b) {
59
- if (isIntegralType (
60
- common_type, /* includeBool=*/ true )) {
61
- if (val_b == 0 ) {
62
- div_by_zero_error = true ;
63
- return static_cast <CTYPE_OUT>(0 );
64
- }
65
- }
66
- CTYPE_IN a_casted =
67
- static_cast <CTYPE_IN>(val_a);
68
- CTYPE_IN b_casted =
69
- static_cast <CTYPE_IN>(val_b);
70
- CTYPE_IN value = utils::floor_divide<CTYPE_IN>(
71
- a_casted, b_casted);
72
-
73
- return static_cast <CTYPE_OUT>(value);
74
- },
75
- a,
76
- b,
77
- out);
78
- });
107
+ out_type, ctx, " floor_divide.out" , CTYPE_OUT, [&]() {
108
+ FloorDivideInner<
109
+ can_cast<CTYPE_IN, CTYPE_OUT>::value,
110
+ CTYPE_A,
111
+ CTYPE_B,
112
+ CTYPE_IN,
113
+ CTYPE_OUT>::run (a, b, out, div_by_zero_error);
79
114
});
80
115
});
81
116
});
0 commit comments