Skip to content

Commit 1f3cb8a

Browse files
committed
elementwise_util: s/common/compute/ almost everywhere and deprecate SAME_AS_COMMON
As the title says, this is mostly a few related find-replaces, plus marking SupportedTensorDtypes::SAME_AS_COMMON deprecated. ghstack-source-id: 96ac25a ghstack-comment-id: 2752741465 Pull-Request-resolved: #9613
1 parent 811352d commit 1f3cb8a

File tree

4 files changed

+123
-117
lines changed

4 files changed

+123
-117
lines changed

kernels/portable/cpu/op_convolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ Tensor& convolution_out(
414414

415415
ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
416416
const auto load_bias = bias.has_value()
417-
? utils::internal::get_load_to_common_fn<CTYPE, name>(
417+
? utils::internal::get_load_to_compute_fn<CTYPE, name>(
418418
bias.value(), utils::SupportedTensorDtypes::REALHBF16)
419419
: nullptr;
420420
convolution_wrapper<CTYPE>(

kernels/portable/cpu/op_cumsum.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ Tensor& cumsum_out(
113113

114114
ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
115115
const auto load_self =
116-
utils::internal::get_load_to_common_fn<CTYPE_OUT, op_name>(
116+
utils::internal::get_load_to_compute_fn<CTYPE_OUT, op_name>(
117117
self, utils::SupportedTensorDtypes::REALHBBF16);
118118
cumsum_tensors<CTYPE_OUT>(self, load_self, dim, out);
119119
});

kernels/portable/cpu/util/dtype_util.h

Lines changed: 103 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -26,189 +26,189 @@ void convert_and_store(From f, void* dst) {
2626
*reinterpret_cast<To*>(dst) = static_cast<To>(f);
2727
}
2828

29-
template <typename CTYPE_COMMON>
30-
using load_to_common_fn = CTYPE_COMMON (*)(const void*);
29+
template <typename CTYPE_COMPUTE>
30+
using load_to_compute_fn = CTYPE_COMPUTE (*)(const void*);
3131

32-
template <typename CTYPE_COMMON, const char* op_name>
33-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
32+
template <typename CTYPE_COMPUTE, const char* op_name>
33+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_realhbbf16(
3434
const Tensor& t) {
35-
CTYPE_COMMON (*result)(const void*) = nullptr;
35+
CTYPE_COMPUTE (*result)(const void*) = nullptr;
3636
ET_SWITCH_REALHBBF16_TYPES(
3737
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
38-
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
38+
result = internal::load_and_convert<CTYPE_COMPUTE, TENSOR_CTYPE>;
3939
});
4040
return result;
4141
}
4242

43-
template <typename CTYPE_COMMON, const char* op_name>
44-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
43+
template <typename CTYPE_COMPUTE, const char* op_name>
44+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_realhbf16(
4545
const Tensor& t) {
46-
CTYPE_COMMON (*result)(const void*) = nullptr;
46+
CTYPE_COMPUTE (*result)(const void*) = nullptr;
4747
ET_SWITCH_REALHBF16_TYPES(
4848
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
49-
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
49+
result = internal::load_and_convert<CTYPE_COMPUTE, TENSOR_CTYPE>;
5050
});
5151
return result;
5252
}
5353

54-
template <typename CTYPE_COMMON, const char* op_name>
55-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16(
54+
template <typename CTYPE_COMPUTE, const char* op_name>
55+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_floathbf16(
5656
const Tensor& t) {
57-
CTYPE_COMMON (*result)(const void*) = nullptr;
57+
CTYPE_COMPUTE (*result)(const void*) = nullptr;
5858
ET_SWITCH_FLOATHBF16_TYPES(
5959
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
60-
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
60+
result = internal::load_and_convert<CTYPE_COMPUTE, TENSOR_CTYPE>;
6161
});
6262
return result;
6363
}
6464

65-
template <typename CTYPE_COMMON, const char* op_name>
66-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
67-
CTYPE_COMMON (*result)(const void*) = nullptr;
65+
template <typename CTYPE_COMPUTE, const char* op_name>
66+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb(const Tensor& t) {
67+
CTYPE_COMPUTE (*result)(const void*) = nullptr;
6868
ET_SWITCH_INT_TYPES_AND(
6969
Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
70-
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
70+
result = internal::load_and_convert<CTYPE_COMPUTE, TENSOR_CTYPE>;
7171
});
7272
return result;
7373
}
7474

75-
template <typename CTYPE_COMMON, const char* op_name>
76-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
75+
template <typename CTYPE_COMPUTE, const char* op_name>
76+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte(
7777
const Tensor& t) {
78-
CTYPE_COMMON (*result)(const void*) = nullptr;
78+
CTYPE_COMPUTE (*result)(const void*) = nullptr;
7979
ET_SWITCH_TWO_TYPES(
8080
Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
81-
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
81+
result = internal::load_and_convert<CTYPE_COMPUTE, TENSOR_CTYPE>;
8282
});
8383
return result;
8484
}
8585

86-
template <typename CTYPE_COMMON, const char* op_name>
87-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute(
86+
template <typename CTYPE_COMPUTE, const char* op_name>
87+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_same_as_compute(
8888
const Tensor& t) {
89-
constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
89+
constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
9090
ET_CHECK_MSG(
9191
t.scalar_type() == common_scalar_type,
9292
"Unhandled dtype %s for %s",
9393
::executorch::runtime::toString(common_scalar_type),
9494
op_name);
95-
return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
95+
return internal::load_and_convert<CTYPE_COMPUTE, CTYPE_COMPUTE>;
9696
}
9797

9898
template <
99-
typename CTYPE_COMMON,
99+
typename CTYPE_COMPUTE,
100100
const char* op_name,
101-
std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
102-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
101+
std::enable_if_t<std::is_same_v<CTYPE_COMPUTE, float>, bool> = true>
102+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_same_as_common(
103103
const Tensor& t) {
104-
CTYPE_COMMON (*result)(const void*) = nullptr;
104+
CTYPE_COMPUTE (*result)(const void*) = nullptr;
105105
ET_SWITCH_THREE_TYPES(
106106
Float, Half, BFloat16, t.scalar_type(), unused, op_name, T, [&]() {
107-
result = internal::load_and_convert<CTYPE_COMMON, T>;
107+
result = internal::load_and_convert<CTYPE_COMPUTE, T>;
108108
});
109109
return result;
110110
}
111111

112112
template <
113-
typename CTYPE_COMMON,
113+
typename CTYPE_COMPUTE,
114114
const char* op_name,
115-
std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
116-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
115+
std::enable_if_t<!std::is_same_v<CTYPE_COMPUTE, float>, bool> = true>
116+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_same_as_common(
117117
const Tensor& t) {
118-
return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
118+
return get_load_to_compute_fn_same_as_compute<CTYPE_COMPUTE, op_name>(t);
119119
}
120120

121-
template <typename CTYPE_COMMON>
122-
using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
121+
template <typename CTYPE_COMPUTE>
122+
using store_compute_to_tensor_fn = void (*)(CTYPE_COMPUTE, void*);
123123

124-
template <typename CTYPE_COMMON, const char* op_name>
125-
store_common_to_tensor_fn<CTYPE_COMMON>
126-
get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
127-
void (*result)(CTYPE_COMMON, void*) = nullptr;
124+
template <typename CTYPE_COMPUTE, const char* op_name>
125+
store_compute_to_tensor_fn<CTYPE_COMPUTE>
126+
get_store_compute_to_tensor_fn_realhbbf16(const Tensor& t) {
127+
void (*result)(CTYPE_COMPUTE, void*) = nullptr;
128128
ET_SWITCH_REALHBBF16_TYPES(
129129
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
130-
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
130+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMPUTE>;
131131
});
132132
return result;
133133
}
134134

135-
template <typename CTYPE_COMMON, const char* op_name>
136-
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
137-
const Tensor& t) {
138-
void (*result)(CTYPE_COMMON, void*) = nullptr;
135+
template <typename CTYPE_COMPUTE, const char* op_name>
136+
store_compute_to_tensor_fn<CTYPE_COMPUTE>
137+
get_store_compute_to_tensor_fn_realhbf16(const Tensor& t) {
138+
void (*result)(CTYPE_COMPUTE, void*) = nullptr;
139139
ET_SWITCH_REALHBF16_TYPES(
140140
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
141-
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
141+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMPUTE>;
142142
});
143143
return result;
144144
}
145145

146-
template <typename CTYPE_COMMON, const char* op_name>
147-
store_common_to_tensor_fn<CTYPE_COMMON>
148-
get_store_common_to_tensor_fn_floathbf16(const Tensor& t) {
149-
void (*result)(CTYPE_COMMON, void*) = nullptr;
146+
template <typename CTYPE_COMPUTE, const char* op_name>
147+
store_compute_to_tensor_fn<CTYPE_COMPUTE>
148+
get_store_compute_to_tensor_fn_floathbf16(const Tensor& t) {
149+
void (*result)(CTYPE_COMPUTE, void*) = nullptr;
150150
ET_SWITCH_FLOATHBF16_TYPES(
151151
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
152-
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
152+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMPUTE>;
153153
});
154154
return result;
155155
}
156156

157-
template <typename CTYPE_COMMON, const char* op_name>
158-
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
157+
template <typename CTYPE_COMPUTE, const char* op_name>
158+
store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb(
159159
const Tensor& t) {
160-
void (*result)(CTYPE_COMMON, void*) = nullptr;
160+
void (*result)(CTYPE_COMPUTE, void*) = nullptr;
161161
ET_SWITCH_INT_TYPES_AND(
162162
Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
163-
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
163+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMPUTE>;
164164
});
165165
return result;
166166
}
167167

168-
template <typename CTYPE_COMMON, const char* op_name>
169-
store_common_to_tensor_fn<CTYPE_COMMON>
170-
get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
171-
void (*result)(CTYPE_COMMON, void*) = nullptr;
168+
template <typename CTYPE_COMPUTE, const char* op_name>
169+
store_compute_to_tensor_fn<CTYPE_COMPUTE>
170+
get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) {
171+
void (*result)(CTYPE_COMPUTE, void*) = nullptr;
172172
ET_SWITCH_TWO_TYPES(
173173
Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
174-
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
174+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMPUTE>;
175175
});
176176
return result;
177177
}
178178

179-
template <typename CTYPE_COMMON, const char* op_name>
180-
store_common_to_tensor_fn<CTYPE_COMMON>
181-
get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
182-
constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
179+
template <typename CTYPE_COMPUTE, const char* op_name>
180+
store_compute_to_tensor_fn<CTYPE_COMPUTE>
181+
get_store_compute_to_tensor_fn_same_as_compute(const Tensor& t) {
182+
constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
183183
ET_CHECK_MSG(
184184
t.scalar_type() == common_scalar_type,
185185
"Unhandled dtype %s for %s",
186186
::executorch::runtime::toString(common_scalar_type),
187187
op_name);
188-
return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
188+
return internal::convert_and_store<CTYPE_COMPUTE, CTYPE_COMPUTE>;
189189
}
190190

191191
template <
192-
typename CTYPE_COMMON,
192+
typename CTYPE_COMPUTE,
193193
const char* op_name,
194-
std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
195-
store_common_to_tensor_fn<CTYPE_COMMON>
196-
get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
197-
void (*result)(CTYPE_COMMON, void*) = nullptr;
194+
std::enable_if_t<std::is_same_v<CTYPE_COMPUTE, float>, bool> = true>
195+
store_compute_to_tensor_fn<CTYPE_COMPUTE>
196+
get_store_compute_to_tensor_fn_same_as_common(const Tensor& t) {
197+
void (*result)(CTYPE_COMPUTE, void*) = nullptr;
198198
ET_SWITCH_THREE_TYPES(
199199
Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() {
200-
result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
200+
result = internal::convert_and_store<CTYPE, CTYPE_COMPUTE>;
201201
});
202202
return result;
203203
}
204204

205205
template <
206-
typename CTYPE_COMMON,
206+
typename CTYPE_COMPUTE,
207207
const char* op_name,
208-
std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
209-
store_common_to_tensor_fn<CTYPE_COMMON>
210-
get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
211-
return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
208+
std::enable_if_t<!std::is_same_v<CTYPE_COMPUTE, float>, bool> = true>
209+
store_compute_to_tensor_fn<CTYPE_COMPUTE>
210+
get_store_compute_to_tensor_fn_same_as_common(const Tensor& t) {
211+
return get_store_compute_to_tensor_fn_same_as_compute<CTYPE_COMPUTE, op_name>(
212212
t);
213213
}
214214

@@ -220,59 +220,64 @@ enum class SupportedTensorDtypes {
220220
FLOATHBF16,
221221
INTB,
222222
BOOL_OR_BYTE,
223+
// DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
223224
SAME_AS_COMPUTE,
224225
SAME_AS_COMMON,
225226
};
226227

227228
namespace internal {
228229

229-
template <typename CTYPE_COMMON, const char* op_name>
230-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
230+
template <typename CTYPE_COMPUTE, const char* op_name>
231+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn(
231232
const Tensor& t,
232233
SupportedTensorDtypes dtypes) {
233234
switch (dtypes) {
234235
case SupportedTensorDtypes::REALHBBF16:
235-
return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
236+
return get_load_to_compute_fn_realhbbf16<CTYPE_COMPUTE, op_name>(t);
236237
case SupportedTensorDtypes::REALHBF16:
237-
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
238+
return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
238239
case SupportedTensorDtypes::FLOATHBF16:
239-
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
240+
return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
240241
case SupportedTensorDtypes::INTB:
241-
return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
242+
return get_load_to_compute_fn_intb<CTYPE_COMPUTE, op_name>(t);
242243
case SupportedTensorDtypes::BOOL_OR_BYTE:
243-
return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
244+
return get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(t);
244245
case SupportedTensorDtypes::SAME_AS_COMPUTE:
245-
return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
246+
return get_load_to_compute_fn_same_as_compute<CTYPE_COMPUTE, op_name>(t);
246247
case SupportedTensorDtypes::SAME_AS_COMMON:
247-
return get_load_to_common_fn_same_as_common<CTYPE_COMMON, op_name>(t);
248+
return get_load_to_compute_fn_same_as_common<CTYPE_COMPUTE, op_name>(t);
248249
}
249250
ET_CHECK(false);
250251
return nullptr;
251252
}
252253

253-
template <typename CTYPE_COMMON, const char* op_name>
254-
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
254+
template <typename CTYPE_COMPUTE, const char* op_name>
255+
store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn(
255256
const Tensor& t,
256257
SupportedTensorDtypes dtypes) {
257258
switch (dtypes) {
258259
case SupportedTensorDtypes::REALHBBF16:
259-
return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
260+
return get_store_compute_to_tensor_fn_realhbbf16<CTYPE_COMPUTE, op_name>(
261+
t);
260262
case SupportedTensorDtypes::REALHBF16:
261-
return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
263+
return get_store_compute_to_tensor_fn_realhbf16<CTYPE_COMPUTE, op_name>(
264+
t);
262265
case SupportedTensorDtypes::FLOATHBF16:
263-
return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
266+
return get_store_compute_to_tensor_fn_floathbf16<CTYPE_COMPUTE, op_name>(
267+
t);
264268
case SupportedTensorDtypes::INTB:
265-
return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
269+
return get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE, op_name>(t);
266270
case SupportedTensorDtypes::BOOL_OR_BYTE:
267-
return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
268-
t);
271+
return get_store_compute_to_tensor_fn_bool_or_byte<
272+
CTYPE_COMPUTE,
273+
op_name>(t);
269274
case SupportedTensorDtypes::SAME_AS_COMPUTE:
270-
return get_store_common_to_tensor_fn_same_as_compute<
271-
CTYPE_COMMON,
275+
return get_store_compute_to_tensor_fn_same_as_compute<
276+
CTYPE_COMPUTE,
272277
op_name>(t);
273278
case SupportedTensorDtypes::SAME_AS_COMMON: {
274-
return get_store_common_to_tensor_fn_same_as_common<
275-
CTYPE_COMMON,
279+
return get_store_compute_to_tensor_fn_same_as_common<
280+
CTYPE_COMPUTE,
276281
op_name>(t);
277282
}
278283
}

0 commit comments

Comments
 (0)