@@ -20,6 +20,52 @@ namespace native {
20
20
21
21
using Tensor = exec_aten::Tensor;
22
22
23
+ namespace {
24
+ template <
25
+ bool can_cast,
26
+ typename CTYPE_A,
27
+ typename CTYPE_B,
28
+ typename CTYPE_IN,
29
+ typename CTYPE_OUT>
30
+ struct RemainderInner ;
31
+
32
+ template <
33
+ typename CTYPE_A,
34
+ typename CTYPE_B,
35
+ typename CTYPE_IN,
36
+ typename CTYPE_OUT>
37
+ struct RemainderInner <true , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
38
+ static void run (const Tensor& a, const Tensor& b, Tensor& out) {
39
+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
40
+ // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
41
+ [](const CTYPE_A val_a, const CTYPE_B val_b) {
42
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
43
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
44
+ CTYPE_IN value = utils::remainder_override (a_casted, b_casted);
45
+
46
+ return static_cast <CTYPE_OUT>(value);
47
+ },
48
+ a,
49
+ b,
50
+ out);
51
+ }
52
+ };
53
+
54
+ struct ReportCanCastBug {
55
+ static void run (const Tensor&, const Tensor&, Tensor&) {
56
+ ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
57
+ }
58
+ };
59
+
60
+ template <
61
+ typename CTYPE_A,
62
+ typename CTYPE_B,
63
+ typename CTYPE_IN,
64
+ typename CTYPE_OUT>
65
+ struct RemainderInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
66
+ : public ReportCanCastBug {};
67
+
68
+ } // namespace
23
69
Tensor& remainder_Tensor_out (
24
70
RuntimeContext& ctx,
25
71
const Tensor& a,
@@ -45,32 +91,17 @@ Tensor& remainder_Tensor_out(
45
91
Bool, a_type, ctx, " remainder.Tensor_out" , CTYPE_A, [&]() {
46
92
ET_SWITCH_REAL_TYPES_AND (
47
93
Bool, b_type, ctx, " remainder.Tensor_out" , CTYPE_B, [&]() {
94
+ using CTYPE_IN = typename torch::executor::
95
+ promote_types<CTYPE_A, CTYPE_B>::type;
96
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
48
97
ET_SWITCH_REAL_TYPES (
49
- common_type, ctx, " remainder.Tensor_out" , CTYPE_IN, [&]() {
50
- ET_SWITCH_REAL_TYPES (
51
- out_type,
52
- ctx,
53
- " remainder.Tensor_out" ,
54
- CTYPE_OUT,
55
- [&]() {
56
- apply_binary_elementwise_fn<
57
- CTYPE_A,
58
- CTYPE_B,
59
- CTYPE_OUT>(
60
- [](const CTYPE_A val_a, const CTYPE_B val_b) {
61
- CTYPE_IN a_casted =
62
- static_cast <CTYPE_IN>(val_a);
63
- CTYPE_IN b_casted =
64
- static_cast <CTYPE_IN>(val_b);
65
- CTYPE_IN value = utils::remainder_override (
66
- a_casted, b_casted);
67
-
68
- return static_cast <CTYPE_OUT>(value);
69
- },
70
- a,
71
- b,
72
- out);
73
- });
98
+ out_type, ctx, " remainder.Tensor_out" , CTYPE_OUT, [&]() {
99
+ RemainderInner<
100
+ can_cast<CTYPE_IN, CTYPE_OUT>::value,
101
+ CTYPE_A,
102
+ CTYPE_B,
103
+ CTYPE_IN,
104
+ CTYPE_OUT>::run (a, b, out);
74
105
});
75
106
});
76
107
});
0 commit comments