@@ -41,6 +41,50 @@ bool can_use_optimized_path(
41
41
(a.numel () == b.numel () && a.numel () == out.numel ()));
42
42
return can_use_optimized_path;
43
43
}
44
+
45
/**
 * Compile-time dispatcher for the element-wise multiply kernel.
 *
 * This is the primary template and is declared only; the partial
 * specializations on `can_cast == true` and `can_cast == false` provide the
 * static `run(a, b, out)` entry point.
 *
 * @tparam can_cast  Selects which specialization is used. Presumably "the
 *                   promoted input type may be cast to the output type" --
 *                   the flag is computed by callers not visible here.
 * @tparam CTYPE_A   C++ element type read from the first operand.
 * @tparam CTYPE_B   C++ element type read from the second operand.
 * @tparam CTYPE_IN  C++ type in which the product is computed.
 * @tparam CTYPE_OUT C++ element type produced for the output.
 */
template <
    bool can_cast,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct MulInner;
52
+
53
+ template <
54
+ typename CTYPE_A,
55
+ typename CTYPE_B,
56
+ typename CTYPE_IN,
57
+ typename CTYPE_OUT>
58
+ struct MulInner <true , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
59
+ static void run (const Tensor& a, const Tensor& b, Tensor& out) {
60
+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
61
+ // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
62
+ [](const CTYPE_A val_a, const CTYPE_B val_b) {
63
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
64
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
65
+ CTYPE_IN value = a_casted * b_casted;
66
+
67
+ return static_cast <CTYPE_OUT>(value);
68
+ },
69
+ a,
70
+ b,
71
+ out);
72
+ }
73
+ };
74
+
75
+ struct ReportCanCastBug {
76
+ static void run (const Tensor&, const Tensor&, Tensor&) {
77
+ ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
78
+ }
79
+ };
80
+
81
+ template <
82
+ typename CTYPE_A,
83
+ typename CTYPE_B,
84
+ typename CTYPE_IN,
85
+ typename CTYPE_OUT>
86
+ struct MulInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
87
+ : public ReportCanCastBug {};
44
88
} // namespace
45
89
46
90
Tensor& opt_mul_out (
@@ -86,20 +130,21 @@ Tensor& opt_mul_out(
86
130
87
131
ET_SWITCH_REALHB_TYPES (a_type, ctx, " mul.out" , CTYPE_A, [&]() {
88
132
ET_SWITCH_REALHB_TYPES (b_type, ctx, " mul.out" , CTYPE_B, [&]() {
89
- ET_SWITCH_REALB_TYPES (common_type, ctx, " mul.out" , CTYPE_IN, [&]() {
90
- ET_SWITCH_REALHB_TYPES (out_type, ctx, " mul.out" , CTYPE_OUT, [&]() {
91
- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
92
- [](const CTYPE_A val_a, const CTYPE_B val_b) {
93
- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
94
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
95
- CTYPE_IN value = a_casted * b_casted;
96
-
97
- return static_cast <CTYPE_OUT>(value);
98
- },
99
- a,
100
- b,
101
- out);
102
- });
133
+ using CTYPE_IN = typename torch::executor::
134
+ promote_types<CTYPE_A, CTYPE_B, /* half_to_float*/ true >::type;
135
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
136
+ ET_SWITCH_REALHB_TYPES (out_type, ctx, " mul.out" , CTYPE_OUT, [&]() {
137
+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
138
+ [](const CTYPE_A val_a, const CTYPE_B val_b) {
139
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
140
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
141
+ CTYPE_IN value = a_casted * b_casted;
142
+
143
+ return static_cast <CTYPE_OUT>(value);
144
+ },
145
+ a,
146
+ b,
147
+ out);
103
148
});
104
149
});
105
150
});
0 commit comments