|
16 | 16 | namespace torch {
|
17 | 17 | namespace executor {
|
18 | 18 | namespace native {
|
| 19 | +namespace { |
| 20 | + |
// Compile-time dispatch helper for the add kernel.
//
// Only declared here. The bool `can_cast` parameter selects which partial
// specialization supplies the static run() entry point:
//   CTYPE_A / CTYPE_B - element types of the two input tensors
//   CTYPE_IN          - type the arithmetic is carried out in
//   CTYPE_OUT         - element type written to the output tensor
template <
    bool can_cast,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct AddInner;
| 28 | + |
| 29 | +template < |
| 30 | + typename CTYPE_A, |
| 31 | + typename CTYPE_B, |
| 32 | + typename CTYPE_IN, |
| 33 | + typename CTYPE_OUT> |
| 34 | +struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> { |
| 35 | + static void |
| 36 | + run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) { |
| 37 | + apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>( |
| 38 | + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) |
| 39 | + [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) { |
| 40 | + CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a); |
| 41 | + CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b); |
| 42 | + CTYPE_IN value = a_casted + alpha_val * b_casted; |
| 43 | + |
| 44 | + return static_cast<CTYPE_OUT>(value); |
| 45 | + }, |
| 46 | + a, |
| 47 | + b, |
| 48 | + out); |
| 49 | + } |
| 50 | +}; |
| 51 | + |
| 52 | +template <typename CTYPE_IN> |
| 53 | +struct ReportCanCastBug { |
| 54 | + static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) { |
| 55 | + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); |
| 56 | + } |
| 57 | +}; |
| 58 | + |
| 59 | +template < |
| 60 | + typename CTYPE_A, |
| 61 | + typename CTYPE_B, |
| 62 | + typename CTYPE_IN, |
| 63 | + typename CTYPE_OUT> |
| 64 | +struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> |
| 65 | + : public ReportCanCastBug<CTYPE_IN> {}; |
| 66 | + |
| 67 | +} // namespace |
19 | 68 |
|
20 | 69 | using Tensor = exec_aten::Tensor;
|
21 | 70 | using ScalarType = exec_aten::ScalarType;
|
@@ -69,26 +118,20 @@ Tensor& opt_add_out(
|
69 | 118 |
|
70 | 119 | ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
|
71 | 120 | ET_SWITCH_REALHB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
|
72 |
| - ET_SWITCH_REALB_TYPES(common_type, ctx, "add.out", CTYPE_IN, [&]() { |
73 |
| - ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() { |
74 |
| - CTYPE_IN alpha_val; |
75 |
| - ET_KERNEL_CHECK( |
76 |
| - ctx, |
77 |
| - utils::extract_scalar(alpha, &alpha_val), |
78 |
| - InvalidArgument, ); |
79 |
| - |
80 |
| - apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>( |
81 |
| - [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) { |
82 |
| - CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a); |
83 |
| - CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b); |
84 |
| - CTYPE_IN value = a_casted + alpha_val * b_casted; |
85 |
| - |
86 |
| - return static_cast<CTYPE_OUT>(value); |
87 |
| - }, |
88 |
| - a, |
89 |
| - b, |
90 |
| - out); |
91 |
| - }); |
| 121 | + using CTYPE_IN = typename torch::executor:: |
| 122 | + promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type; |
| 123 | + ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type); |
| 124 | + ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() { |
| 125 | + CTYPE_IN alpha_val; |
| 126 | + ET_KERNEL_CHECK( |
| 127 | + ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); |
| 128 | + |
| 129 | + AddInner< |
| 130 | + can_cast<CTYPE_IN, CTYPE_OUT>::value, |
| 131 | + CTYPE_A, |
| 132 | + CTYPE_B, |
| 133 | + CTYPE_IN, |
| 134 | + CTYPE_OUT>::run(a, b, alpha_val, out); |
92 | 135 | });
|
93 | 136 | });
|
94 | 137 | });
|
|
0 commit comments