@@ -29,23 +29,24 @@ Tensor& sub_out(
29
29
InvalidArgument,
30
30
out);
31
31
32
- ET_KERNEL_CHECK (ctx, tensor_is_realhb_type (out), InvalidArgument, out);
32
+ ET_KERNEL_CHECK (ctx, tensor_is_realh_type (out), InvalidArgument, out);
33
33
34
34
ScalarType a_type = a.scalar_type ();
35
35
ScalarType b_type = b.scalar_type ();
36
36
ScalarType alpha_type = utils::get_scalar_dtype (alpha);
37
37
ScalarType common_type = promoteTypes (a_type, b_type, /* half_to_float*/ true );
38
38
ScalarType out_type = out.scalar_type ();
39
39
40
+ ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
40
41
ET_KERNEL_CHECK (
41
42
ctx, check_alpha_type (alpha_type, common_type), InvalidArgument, out);
42
- ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
43
- ET_KERNEL_CHECK (ctx, tensor_is_realh_type (out), InvalidArgument, out);
44
43
45
- ET_SWITCH_REALH_TYPES (a_type, ctx, " sub.out" , CTYPE_A, [&]() {
46
- ET_SWITCH_REALH_TYPES (b_type, ctx, " sub.out" , CTYPE_B, [&]() {
47
- ET_SWITCH_REAL_TYPES (common_type, ctx, " sub.out" , CTYPE_IN, [&]() {
48
- ET_SWITCH_REALH_TYPES (out_type, ctx, " sub.out" , CTYPE_OUT, [&]() {
44
+ constexpr auto name = " sub.out" ;
45
+
46
+ ET_SWITCH_REALH_TYPES (a_type, ctx, name, CTYPE_A, [&]() {
47
+ ET_SWITCH_REALH_TYPES (b_type, ctx, name, CTYPE_B, [&]() {
48
+ ET_SWITCH_REAL_TYPES (common_type, ctx, name, CTYPE_IN, [&]() {
49
+ ET_SWITCH_REALH_TYPES (out_type, ctx, name, CTYPE_OUT, [&]() {
49
50
CTYPE_IN alpha_val;
50
51
utils::extract_scalar (alpha, &alpha_val);
51
52
@@ -84,11 +85,11 @@ Tensor& sub_scalar_out(
84
85
out,
85
86
" Failed to resize output tensor." );
86
87
87
- ET_KERNEL_CHECK (ctx, tensor_is_realhb_type (out), InvalidArgument, out);
88
+ ET_KERNEL_CHECK (ctx, tensor_is_realh_type (out), InvalidArgument, out);
88
89
89
90
ScalarType a_type = a.scalar_type ();
90
91
ScalarType b_type = utils::get_scalar_dtype (b);
91
- ScalarType alpha_type = utils::get_scalar_dtype (b );
92
+ ScalarType alpha_type = utils::get_scalar_dtype (alpha );
92
93
ScalarType common_type =
93
94
utils::promote_type_with_scalar (a_type, b, /* half_to_float*/ false );
94
95
ScalarType out_type = out.scalar_type ();
@@ -100,31 +101,30 @@ Tensor& sub_scalar_out(
100
101
common_type = ScalarType::Float;
101
102
}
102
103
103
- ET_SWITCH_REALH_TYPES (a_type, ctx, " sub.Scalar_out" , CTYPE_A, [&]() {
104
- ET_SWITCH_SCALAR_OBJ_REAL_TYPES (
105
- b_type, ctx, " sub.Scalar_out" , CTYPE_B, [&]() {
106
- ET_SWITCH_REAL_TYPES (
107
- common_type, ctx, " sub.Scalar_out" , CTYPE_IN, [&]() {
108
- ET_SWITCH_REALH_TYPES (
109
- out_type, ctx, " sub.Scalar_out" , CTYPE_OUT, [&]() {
110
- CTYPE_B b_val;
111
- utils::extract_scalar (b, &b_val);
112
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
113
- CTYPE_IN alpha_val;
114
- utils::extract_scalar (alpha, &alpha_val);
115
-
116
- apply_unary_map_fn (
117
- [b_casted, alpha_val](const CTYPE_A val_a) {
118
- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
119
- CTYPE_IN value = a_casted - alpha_val * b_casted;
120
- return static_cast <CTYPE_OUT>(value);
121
- },
122
- a.const_data_ptr <CTYPE_A>(),
123
- out.mutable_data_ptr <CTYPE_OUT>(),
124
- out.numel ());
125
- });
126
- });
104
+ constexpr auto name = " sub.Scalar_out" ;
105
+
106
+ ET_SWITCH_REALH_TYPES (a_type, ctx, name, CTYPE_A, [&]() {
107
+ ET_SWITCH_SCALAR_OBJ_REAL_TYPES (b_type, ctx, name, CTYPE_B, [&]() {
108
+ ET_SWITCH_REAL_TYPES (common_type, ctx, name, CTYPE_IN, [&]() {
109
+ ET_SWITCH_REALH_TYPES (out_type, ctx, name, CTYPE_OUT, [&]() {
110
+ CTYPE_B b_val;
111
+ utils::extract_scalar (b, &b_val);
112
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
113
+ CTYPE_IN alpha_val;
114
+ utils::extract_scalar (alpha, &alpha_val);
115
+
116
+ apply_unary_map_fn (
117
+ [b_casted, alpha_val](const CTYPE_A val_a) {
118
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
119
+ CTYPE_IN value = a_casted - alpha_val * b_casted;
120
+ return static_cast <CTYPE_OUT>(value);
121
+ },
122
+ a.const_data_ptr <CTYPE_A>(),
123
+ out.mutable_data_ptr <CTYPE_OUT>(),
124
+ out.numel ());
127
125
});
126
+ });
127
+ });
128
128
});
129
129
130
130
return out;
0 commit comments