@@ -53,31 +53,26 @@ Tensor& opt_le_tensor_out(
               a.numel());
         });
   } else {
-    ScalarType common_type = promoteTypes(a_type, b_type);
     ET_SWITCH_REAL_TYPES_AND(
         Bool, a_type, ctx, "le.Tensor_out", CTYPE_A, [&]() {
           ET_SWITCH_REAL_TYPES_AND(
               Bool, b_type, ctx, "le.Tensor_out", CTYPE_B, [&]() {
+                using CTYPE_IN = typename torch::executor::
+                    promote_types<CTYPE_A, CTYPE_B>::type;
+                ET_DCHECK(
+                    CppTypeToScalarType<CTYPE_IN>::value ==
+                    promoteTypes(a_type, b_type));
                 ET_SWITCH_REAL_TYPES_AND(
-                    Bool, common_type, ctx, "le.Tensor_out", CTYPE_IN, [&]() {
-                      ET_SWITCH_REAL_TYPES_AND(
-                          Bool,
-                          out_type,
-                          ctx,
-                          "le.Tensor_out",
-                          CTYPE_OUT,
-                          [&]() {
-                            const size_t n = a.numel();
-                            const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                            const CTYPE_B* b_data = b.const_data_ptr<CTYPE_B>();
-                            CTYPE_OUT* out_data =
-                                out.mutable_data_ptr<CTYPE_OUT>();
-                            for (auto i = 0; i < n; ++i) {
-                              out_data[i] = static_cast<CTYPE_OUT>(
-                                  static_cast<CTYPE_IN>(a_data[i]) <=
-                                  static_cast<CTYPE_IN>(b_data[i]));
-                            }
-                          });
+                    Bool, out_type, ctx, "le.Tensor_out", CTYPE_OUT, [&]() {
+                      const size_t n = a.numel();
+                      const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
+                      const CTYPE_B* b_data = b.const_data_ptr<CTYPE_B>();
+                      CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+                      for (auto i = 0; i < n; ++i) {
+                        out_data[i] = static_cast<CTYPE_OUT>(
+                            static_cast<CTYPE_IN>(a_data[i]) <=
+                            static_cast<CTYPE_IN>(b_data[i]));
+                      }
                     });
               });
         });
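
For context, a minimal standalone sketch of the pattern this change adopts: once the element types are already dispatched, the promoted ("common") type is computed at compile time instead of through an extra runtime type switch, and a debug check confirms it agrees with the runtime promotion. The `ScalarType` enum, `promote_types`, `promoteTypes`, and `CppTypeToScalarType` below are simplified stand-ins for the ExecuTorch definitions, not the actual library code.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Simplified stand-ins for ExecuTorch's ScalarType / promoteTypes /
// promote_types / CppTypeToScalarType; not the real library definitions.
enum class ScalarType { Int, Float };

template <typename T>
struct CppTypeToScalarType;
template <>
struct CppTypeToScalarType<int32_t> {
  static constexpr ScalarType value = ScalarType::Int;
};
template <>
struct CppTypeToScalarType<float> {
  static constexpr ScalarType value = ScalarType::Float;
};

// Compile-time promotion: any pairing with float promotes to float.
template <typename A, typename B>
struct promote_types {
  using type = std::conditional_t<
      std::is_same_v<A, float> || std::is_same_v<B, float>,
      float,
      int32_t>;
};

// Runtime promotion over the enum, mirroring promoteTypes(a_type, b_type).
inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
  return (a == ScalarType::Float || b == ScalarType::Float)
      ? ScalarType::Float
      : ScalarType::Int;
}

template <typename CTYPE_A, typename CTYPE_B>
bool le_element(CTYPE_A a, CTYPE_B b, ScalarType a_type, ScalarType b_type) {
  // The promoted type is known at compile time, so no additional runtime
  // switch on the common type is needed; the assert mirrors the ET_DCHECK
  // added in the diff.
  using CTYPE_IN = typename promote_types<CTYPE_A, CTYPE_B>::type;
  assert(CppTypeToScalarType<CTYPE_IN>::value == promoteTypes(a_type, b_type));
  return static_cast<CTYPE_IN>(a) <= static_cast<CTYPE_IN>(b);
}

int main() {
  // 2 <= 2.5 compared in the promoted type (float here).
  return le_element(int32_t{2}, 2.5f, ScalarType::Int, ScalarType::Float) ? 0 : 1;
}
```

Dropping the runtime switch over the common type removes one level of `ET_SWITCH` nesting, so the kernel instantiates one fewer combinatorial layer of lambdas while producing the same results.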