Skip to content

Commit da4308f

Browse files
committed
RFC: Compute only in int32/long/float/double for portable ops to save size
Concern: what if we're running on some sort of 16-bit microcontroller where this is a pessimization? ghstack-source-id: a91380c ghstack-comment-id: 2752794343 Pull-Request-resolved: #9635
1 parent 16e745f commit da4308f

File tree

13 files changed

+307
-264
lines changed

13 files changed

+307
-264
lines changed

kernels/portable/cpu/op_add.cpp

Lines changed: 29 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -50,20 +50,21 @@ Tensor& add_out(
5050
// @lint-ignore CLANGTIDY facebook-hte-CArray
5151
static constexpr const char op_name[] = "add.out";
5252

53-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
54-
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
55-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
56-
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
57-
return val_a + val_alpha * val_b;
58-
},
59-
ctx,
60-
a,
61-
utils::SupportedTensorDtypes::REALHBBF16,
62-
b,
63-
utils::SupportedTensorDtypes::REALHBBF16,
64-
out,
65-
utils::SupportedTensorDtypes::REALHBBF16);
66-
});
53+
ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES(
54+
compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
55+
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
56+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
57+
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
58+
return val_a + val_alpha * val_b;
59+
},
60+
ctx,
61+
a,
62+
utils::SupportedTensorDtypes::REALHBBF16,
63+
b,
64+
utils::SupportedTensorDtypes::REALHBBF16,
65+
out,
66+
utils::SupportedTensorDtypes::REALHBBF16);
67+
});
6768

6869
return out;
6970
}
@@ -99,19 +100,20 @@ Tensor& add_scalar_out(
99100
// @lint-ignore CLANGTIDY facebook-hte-CArray
100101
static constexpr const char op_name[] = "add.Scalar_out";
101102

102-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
103-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
104-
[b, alpha](const CTYPE_COMPUTE val_a) {
105-
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
106-
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
107-
return val_a + val_alpha * val_b;
108-
},
109-
ctx,
110-
a,
111-
utils::SupportedTensorDtypes::REALHBBF16,
112-
out,
113-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
114-
});
103+
ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES(
104+
compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
105+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
106+
[b, alpha](const CTYPE_COMPUTE val_a) {
107+
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
108+
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
109+
return val_a + val_alpha * val_b;
110+
},
111+
ctx,
112+
a,
113+
utils::SupportedTensorDtypes::REALHBBF16,
114+
out,
115+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
116+
});
115117

116118
return out;
117119
}

kernels/portable/cpu/op_clamp.cpp

Lines changed: 21 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -133,26 +133,27 @@ Tensor& clamp_out(
133133
// @lint-ignore CLANGTIDY facebook-hte-CArray
134134
static constexpr const char op_name[] = "clamp.out";
135135

136-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
137-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
138-
[has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
139-
CTYPE_COMPUTE val_out = val_in;
140-
if (has_min) {
141-
val_out = utils::max_override(
142-
val_out, utils::scalar_to<CTYPE_COMPUTE>(min_opt.value()));
143-
}
144-
if (has_max) {
145-
val_out = utils::min_override(
146-
val_out, utils::scalar_to<CTYPE_COMPUTE>(max_opt.value()));
147-
}
148-
return val_out;
149-
},
150-
ctx,
151-
in,
152-
utils::SupportedTensorDtypes::REALHBBF16,
153-
out,
154-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
155-
});
136+
ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES(
137+
compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
138+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
139+
[has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
140+
CTYPE_COMPUTE val_out = val_in;
141+
if (has_min) {
142+
val_out = utils::max_override(
143+
val_out, utils::scalar_to<CTYPE_COMPUTE>(min_opt.value()));
144+
}
145+
if (has_max) {
146+
val_out = utils::min_override(
147+
val_out, utils::scalar_to<CTYPE_COMPUTE>(max_opt.value()));
148+
}
149+
return val_out;
150+
},
151+
ctx,
152+
in,
153+
utils::SupportedTensorDtypes::REALHBBF16,
154+
out,
155+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
156+
});
156157

157158
return out;
158159
}

kernels/portable/cpu/op_div.cpp

Lines changed: 49 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -121,34 +121,36 @@ Tensor& div_out_mode(
121121
const bool mode_is_trunc = mode_val == "trunc";
122122
bool div_by_zero_error = false;
123123

124-
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
125-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
126-
[mode_is_trunc, &div_by_zero_error](
127-
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
128-
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
129-
if (val_b == 0) {
130-
div_by_zero_error = true;
131-
return static_cast<CTYPE_COMPUTE>(0);
132-
}
133-
}
134-
CTYPE_COMPUTE value = val_a / val_b;
135-
if (mode_is_trunc) {
136-
value = std::trunc(value);
137-
} else {
138-
// We established above that the mode is either trunc or floor, so
139-
// it must be floor.
140-
value = utils::floor_divide(val_a, val_b);
141-
}
142-
return value;
143-
},
144-
ctx,
145-
a,
146-
utils::SupportedTensorDtypes::REALHBBF16,
147-
b,
148-
utils::SupportedTensorDtypes::REALHBBF16,
149-
out,
150-
utils::SupportedTensorDtypes::REALHBF16);
151-
});
124+
ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES(
125+
compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
126+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
127+
[mode_is_trunc, &div_by_zero_error](
128+
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
129+
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::
130+
value) {
131+
if (val_b == 0) {
132+
div_by_zero_error = true;
133+
return static_cast<CTYPE_COMPUTE>(0);
134+
}
135+
}
136+
CTYPE_COMPUTE value = val_a / val_b;
137+
if (mode_is_trunc) {
138+
value = std::trunc(value);
139+
} else {
140+
// We established above that the mode is either trunc or floor,
141+
// so it must be floor.
142+
value = utils::floor_divide(val_a, val_b);
143+
}
144+
return value;
145+
},
146+
ctx,
147+
a,
148+
utils::SupportedTensorDtypes::REALHBBF16,
149+
b,
150+
utils::SupportedTensorDtypes::REALHBBF16,
151+
out,
152+
utils::SupportedTensorDtypes::REALHBF16);
153+
});
152154

153155
ET_KERNEL_CHECK_MSG(
154156
ctx,
@@ -252,24 +254,25 @@ Tensor& div_scalar_mode_out(
252254
// @lint-ignore CLANGTIDY facebook-hte-CArray
253255
static constexpr const char op_name[] = "div.Scalar_mode_out";
254256

255-
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
256-
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
257-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
258-
[val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) {
259-
CTYPE_COMPUTE value = val_a / val_b;
260-
if (mode_is_trunc) {
261-
value = std::trunc(value);
262-
} else {
263-
value = utils::floor_divide(val_a, val_b);
264-
}
265-
return value;
266-
},
267-
ctx,
268-
a,
269-
utils::SupportedTensorDtypes::REALHBBF16,
270-
out,
271-
utils::SupportedTensorDtypes::REALHBF16);
272-
});
257+
ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES(
258+
compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
259+
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
260+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
261+
[val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) {
262+
CTYPE_COMPUTE value = val_a / val_b;
263+
if (mode_is_trunc) {
264+
value = std::trunc(value);
265+
} else {
266+
value = utils::floor_divide(val_a, val_b);
267+
}
268+
return value;
269+
},
270+
ctx,
271+
a,
272+
utils::SupportedTensorDtypes::REALHBBF16,
273+
out,
274+
utils::SupportedTensorDtypes::REALHBF16);
275+
});
273276

274277
return out;
275278
}

kernels/portable/cpu/op_floor_divide.cpp

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -52,26 +52,28 @@ Tensor& floor_divide_out(
5252

5353
bool div_by_zero_error = false;
5454

55-
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
57-
[&div_by_zero_error](
58-
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59-
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
60-
if (val_b == 0) {
61-
div_by_zero_error = true;
62-
return static_cast<CTYPE_COMPUTE>(0);
63-
}
64-
}
65-
return utils::floor_divide(val_a, val_b);
66-
},
67-
ctx,
68-
a,
69-
utils::SupportedTensorDtypes::REALHBBF16,
70-
b,
71-
utils::SupportedTensorDtypes::REALHBBF16,
72-
out,
73-
utils::SupportedTensorDtypes::REALHBF16);
74-
});
55+
ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES(
56+
compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
57+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
58+
[&div_by_zero_error](
59+
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
60+
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::
61+
value) {
62+
if (val_b == 0) {
63+
div_by_zero_error = true;
64+
return static_cast<CTYPE_COMPUTE>(0);
65+
}
66+
}
67+
return utils::floor_divide(val_a, val_b);
68+
},
69+
ctx,
70+
a,
71+
utils::SupportedTensorDtypes::REALHBBF16,
72+
b,
73+
utils::SupportedTensorDtypes::REALHBBF16,
74+
out,
75+
utils::SupportedTensorDtypes::REALHBF16);
76+
});
7577

7678
ET_KERNEL_CHECK_MSG(
7779
ctx,

kernels/portable/cpu/op_maximum.cpp

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -44,19 +44,20 @@ Tensor& maximum_out(
4444
// @lint-ignore CLANGTIDY facebook-hte-CArray
4545
static constexpr const char op_name[] = "maximum.out";
4646

47-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
48-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
49-
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
50-
return utils::max_override(val_a, val_b);
51-
},
52-
ctx,
53-
a,
54-
utils::SupportedTensorDtypes::REALHBBF16,
55-
b,
56-
utils::SupportedTensorDtypes::REALHBBF16,
57-
out,
58-
utils::SupportedTensorDtypes::REALHBBF16);
59-
});
47+
ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES(
48+
compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
49+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
50+
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
51+
return utils::max_override(val_a, val_b);
52+
},
53+
ctx,
54+
a,
55+
utils::SupportedTensorDtypes::REALHBBF16,
56+
b,
57+
utils::SupportedTensorDtypes::REALHBBF16,
58+
out,
59+
utils::SupportedTensorDtypes::REALHBBF16);
60+
});
6061

6162
return out;
6263
}

kernels/portable/cpu/op_minimum.cpp

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -44,19 +44,20 @@ Tensor& minimum_out(
4444
// @lint-ignore CLANGTIDY facebook-hte-CArray
4545
static constexpr const char op_name[] = "minimum.out";
4646

47-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
48-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
49-
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
50-
return utils::min_override(val_a, val_b);
51-
},
52-
ctx,
53-
a,
54-
utils::SupportedTensorDtypes::REALHBBF16,
55-
b,
56-
utils::SupportedTensorDtypes::REALHBBF16,
57-
out,
58-
utils::SupportedTensorDtypes::REALHBBF16);
59-
});
47+
ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES(
48+
compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
49+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
50+
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
51+
return utils::min_override(val_a, val_b);
52+
},
53+
ctx,
54+
a,
55+
utils::SupportedTensorDtypes::REALHBBF16,
56+
b,
57+
utils::SupportedTensorDtypes::REALHBBF16,
58+
out,
59+
utils::SupportedTensorDtypes::REALHBBF16);
60+
});
6061

6162
return out;
6263
}

0 commit comments

Comments (0)