Skip to content

Commit da94594

Browse files
Fix typo in sub & clean up (#3100) (#3253)
Summary: Pull Request resolved: #3100 Reviewed By: kirklandsign Differential Revision: D56255838 fbshipit-source-id: b6567320b557aeb287db66b43447db9caabebd13 (cherry picked from commit e69a662) Co-authored-by: Manuel Candales <[email protected]>
1 parent eabdeb0 commit da94594

File tree

2 files changed

+62
-60
lines changed

2 files changed

+62
-60
lines changed

kernels/portable/cpu/op_add.cpp

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,12 @@ Tensor& add_out(
4141
ET_KERNEL_CHECK(
4242
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
4343

44-
ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
45-
ET_SWITCH_REALHB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
46-
ET_SWITCH_REALB_TYPES(common_type, ctx, "add.out", CTYPE_IN, [&]() {
47-
ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
44+
constexpr auto name = "add.out";
45+
46+
ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
47+
ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
48+
ET_SWITCH_REALB_TYPES(common_type, ctx, name, CTYPE_IN, [&]() {
49+
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
4850
CTYPE_IN alpha_val;
4951
utils::extract_scalar(alpha, &alpha_val);
5052

@@ -99,29 +101,29 @@ Tensor& add_scalar_out(
99101
common_type = ScalarType::Float;
100102
}
101103

102-
ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
103-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
104-
ET_SWITCH_REALB_TYPES(
105-
common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
106-
ET_SWITCH_REALHB_TYPES(
107-
out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
108-
CTYPE_B b_val;
109-
utils::extract_scalar(b, &b_val);
110-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
111-
CTYPE_IN alpha_val;
112-
utils::extract_scalar(alpha, &alpha_val);
113-
114-
apply_unary_map_fn(
115-
[b_casted, alpha_val](const CTYPE_A val_a) {
116-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
117-
CTYPE_IN value = a_casted + alpha_val * b_casted;
118-
return static_cast<CTYPE_OUT>(value);
119-
},
120-
a.const_data_ptr<CTYPE_A>(),
121-
out.mutable_data_ptr<CTYPE_OUT>(),
122-
out.numel());
123-
});
124-
});
104+
constexpr auto name = "add.Scalar_out";
105+
106+
ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
107+
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
108+
ET_SWITCH_REALB_TYPES(common_type, ctx, name, CTYPE_IN, [&]() {
109+
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
110+
CTYPE_B b_val;
111+
utils::extract_scalar(b, &b_val);
112+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
113+
CTYPE_IN alpha_val;
114+
utils::extract_scalar(alpha, &alpha_val);
115+
116+
apply_unary_map_fn(
117+
[b_casted, alpha_val](const CTYPE_A val_a) {
118+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
119+
CTYPE_IN value = a_casted + alpha_val * b_casted;
120+
return static_cast<CTYPE_OUT>(value);
121+
},
122+
a.const_data_ptr<CTYPE_A>(),
123+
out.mutable_data_ptr<CTYPE_OUT>(),
124+
out.numel());
125+
});
126+
});
125127
});
126128
});
127129

kernels/portable/cpu/op_sub.cpp

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,24 @@ Tensor& sub_out(
2929
InvalidArgument,
3030
out);
3131

32-
ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
32+
ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out);
3333

3434
ScalarType a_type = a.scalar_type();
3535
ScalarType b_type = b.scalar_type();
3636
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
3737
ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
3838
ScalarType out_type = out.scalar_type();
3939

40+
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
4041
ET_KERNEL_CHECK(
4142
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
42-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
43-
ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out);
4443

45-
ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() {
46-
ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() {
47-
ET_SWITCH_REAL_TYPES(common_type, ctx, "sub.out", CTYPE_IN, [&]() {
48-
ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
44+
constexpr auto name = "sub.out";
45+
46+
ET_SWITCH_REALH_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
47+
ET_SWITCH_REALH_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
48+
ET_SWITCH_REAL_TYPES(common_type, ctx, name, CTYPE_IN, [&]() {
49+
ET_SWITCH_REALH_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
4950
CTYPE_IN alpha_val;
5051
utils::extract_scalar(alpha, &alpha_val);
5152

@@ -84,11 +85,11 @@ Tensor& sub_scalar_out(
8485
out,
8586
"Failed to resize output tensor.");
8687

87-
ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
88+
ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out);
8889

8990
ScalarType a_type = a.scalar_type();
9091
ScalarType b_type = utils::get_scalar_dtype(b);
91-
ScalarType alpha_type = utils::get_scalar_dtype(b);
92+
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
9293
ScalarType common_type =
9394
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
9495
ScalarType out_type = out.scalar_type();
@@ -100,31 +101,30 @@ Tensor& sub_scalar_out(
100101
common_type = ScalarType::Float;
101102
}
102103

103-
ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
104-
ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
105-
b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
106-
ET_SWITCH_REAL_TYPES(
107-
common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
108-
ET_SWITCH_REALH_TYPES(
109-
out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
110-
CTYPE_B b_val;
111-
utils::extract_scalar(b, &b_val);
112-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
113-
CTYPE_IN alpha_val;
114-
utils::extract_scalar(alpha, &alpha_val);
115-
116-
apply_unary_map_fn(
117-
[b_casted, alpha_val](const CTYPE_A val_a) {
118-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
119-
CTYPE_IN value = a_casted - alpha_val * b_casted;
120-
return static_cast<CTYPE_OUT>(value);
121-
},
122-
a.const_data_ptr<CTYPE_A>(),
123-
out.mutable_data_ptr<CTYPE_OUT>(),
124-
out.numel());
125-
});
126-
});
104+
constexpr auto name = "sub.Scalar_out";
105+
106+
ET_SWITCH_REALH_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
107+
ET_SWITCH_SCALAR_OBJ_REAL_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
108+
ET_SWITCH_REAL_TYPES(common_type, ctx, name, CTYPE_IN, [&]() {
109+
ET_SWITCH_REALH_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
110+
CTYPE_B b_val;
111+
utils::extract_scalar(b, &b_val);
112+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
113+
CTYPE_IN alpha_val;
114+
utils::extract_scalar(alpha, &alpha_val);
115+
116+
apply_unary_map_fn(
117+
[b_casted, alpha_val](const CTYPE_A val_a) {
118+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
119+
CTYPE_IN value = a_casted - alpha_val * b_casted;
120+
return static_cast<CTYPE_OUT>(value);
121+
},
122+
a.const_data_ptr<CTYPE_A>(),
123+
out.mutable_data_ptr<CTYPE_OUT>(),
124+
out.numel());
127125
});
126+
});
127+
});
128128
});
129129

130130
return out;

0 commit comments

Comments
 (0)