@@ -20,6 +20,50 @@ const T& min(const T& a, const T& b) {
20
20
return (b < a) ? b : a;
21
21
}
22
22
23
+ template <
24
+ bool can_cast,
25
+ typename CTYPE_A,
26
+ typename CTYPE_B,
27
+ typename CTYPE_IN,
28
+ typename CTYPE_OUT>
29
+ struct MinimumInner ;
30
+
31
+ template <
32
+ typename CTYPE_A,
33
+ typename CTYPE_B,
34
+ typename CTYPE_IN,
35
+ typename CTYPE_OUT>
36
+ struct MinimumInner <true , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37
+ static void run (const Tensor& a, const Tensor& b, Tensor& out) {
38
+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39
+ // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
40
+ [](const CTYPE_A val_a, const CTYPE_B val_b) {
41
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
42
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
43
+ CTYPE_IN value = min (a_casted, b_casted);
44
+
45
+ return static_cast <CTYPE_OUT>(value);
46
+ },
47
+ a,
48
+ b,
49
+ out);
50
+ }
51
+ };
52
+
53
+ struct ReportCanCastBug {
54
+ static void run (const Tensor&, const Tensor&, Tensor&) {
55
+ ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
56
+ }
57
+ };
58
+
59
+ template <
60
+ typename CTYPE_A,
61
+ typename CTYPE_B,
62
+ typename CTYPE_IN,
63
+ typename CTYPE_OUT>
64
+ struct MinimumInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
65
+ : public ReportCanCastBug {};
66
+
23
67
} // namespace
24
68
25
69
Tensor& minimum_out (
@@ -44,22 +88,17 @@ Tensor& minimum_out(
44
88
45
89
ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " minimum.out" , CTYPE_A, [&]() {
46
90
ET_SWITCH_REAL_TYPES_AND (Bool, b_type, ctx, " minimum.out" , CTYPE_B, [&]() {
91
+ using CTYPE_IN =
92
+ typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
93
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
47
94
ET_SWITCH_REAL_TYPES_AND (
48
- Bool, common_type, ctx, " minimum.out" , CTYPE_IN, [&]() {
49
- ET_SWITCH_REAL_TYPES_AND (
50
- Bool, out_type, ctx, " minimum.out" , CTYPE_OUT, [&]() {
51
- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
52
- [](const CTYPE_A val_a, const CTYPE_B val_b) {
53
- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
54
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
55
- CTYPE_IN value = min (a_casted, b_casted);
56
-
57
- return static_cast <CTYPE_OUT>(value);
58
- },
59
- a,
60
- b,
61
- out);
62
- });
95
+ Bool, out_type, ctx, " minimum.out" , CTYPE_OUT, [&]() {
96
+ MinimumInner<
97
+ can_cast<CTYPE_IN, CTYPE_OUT>::value,
98
+ CTYPE_A,
99
+ CTYPE_B,
100
+ CTYPE_IN,
101
+ CTYPE_OUT>::run (a, b, out);
63
102
});
64
103
});
65
104
});
0 commit comments