@@ -26,189 +26,189 @@ void convert_and_store(From f, void* dst) {
26
26
*reinterpret_cast <To*>(dst) = static_cast <To>(f);
27
27
}
28
28
29
- template <typename CTYPE_COMMON >
30
- using load_to_common_fn = CTYPE_COMMON (*)(const void *);
29
+ template <typename CTYPE_COMPUTE >
30
+ using load_to_compute_fn = CTYPE_COMPUTE (*)(const void *);
31
31
32
- template <typename CTYPE_COMMON , const char * op_name>
33
- load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16 (
32
+ template <typename CTYPE_COMPUTE , const char * op_name>
33
+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_realhbbf16 (
34
34
const Tensor& t) {
35
- CTYPE_COMMON (*result)(const void *) = nullptr ;
35
+ CTYPE_COMPUTE (*result)(const void *) = nullptr ;
36
36
ET_SWITCH_REALHBBF16_TYPES (
37
37
t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
38
- result = internal::load_and_convert<CTYPE_COMMON , TENSOR_CTYPE>;
38
+ result = internal::load_and_convert<CTYPE_COMPUTE , TENSOR_CTYPE>;
39
39
});
40
40
return result;
41
41
}
42
42
43
- template <typename CTYPE_COMMON , const char * op_name>
44
- load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16 (
43
+ template <typename CTYPE_COMPUTE , const char * op_name>
44
+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_realhbf16 (
45
45
const Tensor& t) {
46
- CTYPE_COMMON (*result)(const void *) = nullptr ;
46
+ CTYPE_COMPUTE (*result)(const void *) = nullptr ;
47
47
ET_SWITCH_REALHBF16_TYPES (
48
48
t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
49
- result = internal::load_and_convert<CTYPE_COMMON , TENSOR_CTYPE>;
49
+ result = internal::load_and_convert<CTYPE_COMPUTE , TENSOR_CTYPE>;
50
50
});
51
51
return result;
52
52
}
53
53
54
- template <typename CTYPE_COMMON , const char * op_name>
55
- load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16 (
54
+ template <typename CTYPE_COMPUTE , const char * op_name>
55
+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_floathbf16 (
56
56
const Tensor& t) {
57
- CTYPE_COMMON (*result)(const void *) = nullptr ;
57
+ CTYPE_COMPUTE (*result)(const void *) = nullptr ;
58
58
ET_SWITCH_FLOATHBF16_TYPES (
59
59
t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
60
- result = internal::load_and_convert<CTYPE_COMMON , TENSOR_CTYPE>;
60
+ result = internal::load_and_convert<CTYPE_COMPUTE , TENSOR_CTYPE>;
61
61
});
62
62
return result;
63
63
}
64
64
65
- template <typename CTYPE_COMMON , const char * op_name>
66
- load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb (const Tensor& t) {
67
- CTYPE_COMMON (*result)(const void *) = nullptr ;
65
+ template <typename CTYPE_COMPUTE , const char * op_name>
66
+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb (const Tensor& t) {
67
+ CTYPE_COMPUTE (*result)(const void *) = nullptr ;
68
68
ET_SWITCH_INT_TYPES_AND (
69
69
Bool, t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
70
- result = internal::load_and_convert<CTYPE_COMMON , TENSOR_CTYPE>;
70
+ result = internal::load_and_convert<CTYPE_COMPUTE , TENSOR_CTYPE>;
71
71
});
72
72
return result;
73
73
}
74
74
75
- template <typename CTYPE_COMMON , const char * op_name>
76
- load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte (
75
+ template <typename CTYPE_COMPUTE , const char * op_name>
76
+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte (
77
77
const Tensor& t) {
78
- CTYPE_COMMON (*result)(const void *) = nullptr ;
78
+ CTYPE_COMPUTE (*result)(const void *) = nullptr ;
79
79
ET_SWITCH_TWO_TYPES (
80
80
Bool, Byte , t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
81
- result = internal::load_and_convert<CTYPE_COMMON , TENSOR_CTYPE>;
81
+ result = internal::load_and_convert<CTYPE_COMPUTE , TENSOR_CTYPE>;
82
82
});
83
83
return result;
84
84
}
85
85
86
- template <typename CTYPE_COMMON , const char * op_name>
87
- load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute (
86
+ template <typename CTYPE_COMPUTE , const char * op_name>
87
+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_same_as_compute (
88
88
const Tensor& t) {
89
- constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON >::value;
89
+ constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMPUTE >::value;
90
90
ET_CHECK_MSG (
91
91
t.scalar_type () == common_scalar_type,
92
92
" Unhandled dtype %s for %s" ,
93
93
::executorch::runtime::toString (common_scalar_type),
94
94
op_name);
95
- return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON >;
95
+ return internal::load_and_convert<CTYPE_COMPUTE, CTYPE_COMPUTE >;
96
96
}
97
97
98
98
template <
99
- typename CTYPE_COMMON ,
99
+ typename CTYPE_COMPUTE ,
100
100
const char * op_name,
101
- std::enable_if_t <std::is_same_v<CTYPE_COMMON , float >, bool > = true >
102
- load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common (
101
+ std::enable_if_t <std::is_same_v<CTYPE_COMPUTE , float >, bool > = true >
102
+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_same_as_common (
103
103
const Tensor& t) {
104
- CTYPE_COMMON (*result)(const void *) = nullptr ;
104
+ CTYPE_COMPUTE (*result)(const void *) = nullptr ;
105
105
ET_SWITCH_THREE_TYPES (
106
106
Float, Half, BFloat16, t.scalar_type (), unused, op_name, T, [&]() {
107
- result = internal::load_and_convert<CTYPE_COMMON , T>;
107
+ result = internal::load_and_convert<CTYPE_COMPUTE , T>;
108
108
});
109
109
return result;
110
110
}
111
111
112
112
template <
113
- typename CTYPE_COMMON ,
113
+ typename CTYPE_COMPUTE ,
114
114
const char * op_name,
115
- std::enable_if_t <!std::is_same_v<CTYPE_COMMON , float >, bool > = true >
116
- load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common (
115
+ std::enable_if_t <!std::is_same_v<CTYPE_COMPUTE , float >, bool > = true >
116
+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_same_as_common (
117
117
const Tensor& t) {
118
- return get_load_to_common_fn_same_as_compute<CTYPE_COMMON , op_name>(t);
118
+ return get_load_to_compute_fn_same_as_compute<CTYPE_COMPUTE , op_name>(t);
119
119
}
120
120
121
- template <typename CTYPE_COMMON >
122
- using store_common_to_tensor_fn = void (*)(CTYPE_COMMON , void *);
121
+ template <typename CTYPE_COMPUTE >
122
+ using store_compute_to_tensor_fn = void (*)(CTYPE_COMPUTE , void *);
123
123
124
- template <typename CTYPE_COMMON , const char * op_name>
125
- store_common_to_tensor_fn<CTYPE_COMMON >
126
- get_store_common_to_tensor_fn_realhbbf16 (const Tensor& t) {
127
- void (*result)(CTYPE_COMMON , void *) = nullptr ;
124
+ template <typename CTYPE_COMPUTE , const char * op_name>
125
+ store_compute_to_tensor_fn<CTYPE_COMPUTE >
126
+ get_store_compute_to_tensor_fn_realhbbf16 (const Tensor& t) {
127
+ void (*result)(CTYPE_COMPUTE , void *) = nullptr ;
128
128
ET_SWITCH_REALHBBF16_TYPES (
129
129
t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
130
- result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON >;
130
+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMPUTE >;
131
131
});
132
132
return result;
133
133
}
134
134
135
- template <typename CTYPE_COMMON , const char * op_name>
136
- store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16 (
137
- const Tensor& t) {
138
- void (*result)(CTYPE_COMMON , void *) = nullptr ;
135
+ template <typename CTYPE_COMPUTE , const char * op_name>
136
+ store_compute_to_tensor_fn<CTYPE_COMPUTE>
137
+ get_store_compute_to_tensor_fn_realhbf16 ( const Tensor& t) {
138
+ void (*result)(CTYPE_COMPUTE , void *) = nullptr ;
139
139
ET_SWITCH_REALHBF16_TYPES (
140
140
t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
141
- result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON >;
141
+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMPUTE >;
142
142
});
143
143
return result;
144
144
}
145
145
146
- template <typename CTYPE_COMMON , const char * op_name>
147
- store_common_to_tensor_fn<CTYPE_COMMON >
148
- get_store_common_to_tensor_fn_floathbf16 (const Tensor& t) {
149
- void (*result)(CTYPE_COMMON , void *) = nullptr ;
146
+ template <typename CTYPE_COMPUTE , const char * op_name>
147
+ store_compute_to_tensor_fn<CTYPE_COMPUTE >
148
+ get_store_compute_to_tensor_fn_floathbf16 (const Tensor& t) {
149
+ void (*result)(CTYPE_COMPUTE , void *) = nullptr ;
150
150
ET_SWITCH_FLOATHBF16_TYPES (
151
151
t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
152
- result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON >;
152
+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMPUTE >;
153
153
});
154
154
return result;
155
155
}
156
156
157
- template <typename CTYPE_COMMON , const char * op_name>
158
- store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb (
157
+ template <typename CTYPE_COMPUTE , const char * op_name>
158
+ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb (
159
159
const Tensor& t) {
160
- void (*result)(CTYPE_COMMON , void *) = nullptr ;
160
+ void (*result)(CTYPE_COMPUTE , void *) = nullptr ;
161
161
ET_SWITCH_INT_TYPES_AND (
162
162
Bool, t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
163
- result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON >;
163
+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMPUTE >;
164
164
});
165
165
return result;
166
166
}
167
167
168
- template <typename CTYPE_COMMON , const char * op_name>
169
- store_common_to_tensor_fn<CTYPE_COMMON >
170
- get_store_common_to_tensor_fn_bool_or_byte (const Tensor& t) {
171
- void (*result)(CTYPE_COMMON , void *) = nullptr ;
168
+ template <typename CTYPE_COMPUTE , const char * op_name>
169
+ store_compute_to_tensor_fn<CTYPE_COMPUTE >
170
+ get_store_compute_to_tensor_fn_bool_or_byte (const Tensor& t) {
171
+ void (*result)(CTYPE_COMPUTE , void *) = nullptr ;
172
172
ET_SWITCH_TWO_TYPES (
173
173
Bool, Byte , t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
174
- result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON >;
174
+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMPUTE >;
175
175
});
176
176
return result;
177
177
}
178
178
179
- template <typename CTYPE_COMMON , const char * op_name>
180
- store_common_to_tensor_fn<CTYPE_COMMON >
181
- get_store_common_to_tensor_fn_same_as_compute (const Tensor& t) {
182
- constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON >::value;
179
+ template <typename CTYPE_COMPUTE , const char * op_name>
180
+ store_compute_to_tensor_fn<CTYPE_COMPUTE >
181
+ get_store_compute_to_tensor_fn_same_as_compute (const Tensor& t) {
182
+ constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMPUTE >::value;
183
183
ET_CHECK_MSG (
184
184
t.scalar_type () == common_scalar_type,
185
185
" Unhandled dtype %s for %s" ,
186
186
::executorch::runtime::toString (common_scalar_type),
187
187
op_name);
188
- return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON >;
188
+ return internal::convert_and_store<CTYPE_COMPUTE, CTYPE_COMPUTE >;
189
189
}
190
190
191
191
template <
192
- typename CTYPE_COMMON ,
192
+ typename CTYPE_COMPUTE ,
193
193
const char * op_name,
194
- std::enable_if_t <std::is_same_v<CTYPE_COMMON , float >, bool > = true >
195
- store_common_to_tensor_fn<CTYPE_COMMON >
196
- get_store_common_to_tensor_fn_same_as_common (const Tensor& t) {
197
- void (*result)(CTYPE_COMMON , void *) = nullptr ;
194
+ std::enable_if_t <std::is_same_v<CTYPE_COMPUTE , float >, bool > = true >
195
+ store_compute_to_tensor_fn<CTYPE_COMPUTE >
196
+ get_store_compute_to_tensor_fn_same_as_common (const Tensor& t) {
197
+ void (*result)(CTYPE_COMPUTE , void *) = nullptr ;
198
198
ET_SWITCH_THREE_TYPES (
199
199
Float, Half, BFloat16, t.scalar_type (), unused, op_name, CTYPE, [&]() {
200
- result = internal::convert_and_store<CTYPE, CTYPE_COMMON >;
200
+ result = internal::convert_and_store<CTYPE, CTYPE_COMPUTE >;
201
201
});
202
202
return result;
203
203
}
204
204
205
205
template <
206
- typename CTYPE_COMMON ,
206
+ typename CTYPE_COMPUTE ,
207
207
const char * op_name,
208
- std::enable_if_t <!std::is_same_v<CTYPE_COMMON , float >, bool > = true >
209
- store_common_to_tensor_fn<CTYPE_COMMON >
210
- get_store_common_to_tensor_fn_same_as_common (const Tensor& t) {
211
- return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON , op_name>(
208
+ std::enable_if_t <!std::is_same_v<CTYPE_COMPUTE , float >, bool > = true >
209
+ store_compute_to_tensor_fn<CTYPE_COMPUTE >
210
+ get_store_compute_to_tensor_fn_same_as_common (const Tensor& t) {
211
+ return get_store_compute_to_tensor_fn_same_as_compute<CTYPE_COMPUTE , op_name>(
212
212
t);
213
213
}
214
214
@@ -220,59 +220,64 @@ enum class SupportedTensorDtypes {
220
220
FLOATHBF16,
221
221
INTB,
222
222
BOOL_OR_BYTE,
223
+ // DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
223
224
SAME_AS_COMPUTE,
224
225
SAME_AS_COMMON,
225
226
};
226
227
227
228
namespace internal {
228
229
229
- template <typename CTYPE_COMMON , const char * op_name>
230
- load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn (
230
+ template <typename CTYPE_COMPUTE , const char * op_name>
231
+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn (
231
232
const Tensor& t,
232
233
SupportedTensorDtypes dtypes) {
233
234
switch (dtypes) {
234
235
case SupportedTensorDtypes::REALHBBF16:
235
- return get_load_to_common_fn_realhbbf16<CTYPE_COMMON , op_name>(t);
236
+ return get_load_to_compute_fn_realhbbf16<CTYPE_COMPUTE , op_name>(t);
236
237
case SupportedTensorDtypes::REALHBF16:
237
- return get_load_to_common_fn_realhbf16<CTYPE_COMMON , op_name>(t);
238
+ return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE , op_name>(t);
238
239
case SupportedTensorDtypes::FLOATHBF16:
239
- return get_load_to_common_fn_realhbf16<CTYPE_COMMON , op_name>(t);
240
+ return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE , op_name>(t);
240
241
case SupportedTensorDtypes::INTB:
241
- return get_load_to_common_fn_intb<CTYPE_COMMON , op_name>(t);
242
+ return get_load_to_compute_fn_intb<CTYPE_COMPUTE , op_name>(t);
242
243
case SupportedTensorDtypes::BOOL_OR_BYTE:
243
- return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON , op_name>(t);
244
+ return get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE , op_name>(t);
244
245
case SupportedTensorDtypes::SAME_AS_COMPUTE:
245
- return get_load_to_common_fn_same_as_compute<CTYPE_COMMON , op_name>(t);
246
+ return get_load_to_compute_fn_same_as_compute<CTYPE_COMPUTE , op_name>(t);
246
247
case SupportedTensorDtypes::SAME_AS_COMMON:
247
- return get_load_to_common_fn_same_as_common<CTYPE_COMMON , op_name>(t);
248
+ return get_load_to_compute_fn_same_as_common<CTYPE_COMPUTE , op_name>(t);
248
249
}
249
250
ET_CHECK (false );
250
251
return nullptr ;
251
252
}
252
253
253
- template <typename CTYPE_COMMON , const char * op_name>
254
- store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn (
254
+ template <typename CTYPE_COMPUTE , const char * op_name>
255
+ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn (
255
256
const Tensor& t,
256
257
SupportedTensorDtypes dtypes) {
257
258
switch (dtypes) {
258
259
case SupportedTensorDtypes::REALHBBF16:
259
- return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
260
+ return get_store_compute_to_tensor_fn_realhbbf16<CTYPE_COMPUTE, op_name>(
261
+ t);
260
262
case SupportedTensorDtypes::REALHBF16:
261
- return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
263
+ return get_store_compute_to_tensor_fn_realhbf16<CTYPE_COMPUTE, op_name>(
264
+ t);
262
265
case SupportedTensorDtypes::FLOATHBF16:
263
- return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
266
+ return get_store_compute_to_tensor_fn_floathbf16<CTYPE_COMPUTE, op_name>(
267
+ t);
264
268
case SupportedTensorDtypes::INTB:
265
- return get_store_common_to_tensor_fn_intb<CTYPE_COMMON , op_name>(t);
269
+ return get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE , op_name>(t);
266
270
case SupportedTensorDtypes::BOOL_OR_BYTE:
267
- return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
268
- t);
271
+ return get_store_compute_to_tensor_fn_bool_or_byte<
272
+ CTYPE_COMPUTE,
273
+ op_name>(t);
269
274
case SupportedTensorDtypes::SAME_AS_COMPUTE:
270
- return get_store_common_to_tensor_fn_same_as_compute <
271
- CTYPE_COMMON ,
275
+ return get_store_compute_to_tensor_fn_same_as_compute <
276
+ CTYPE_COMPUTE ,
272
277
op_name>(t);
273
278
case SupportedTensorDtypes::SAME_AS_COMMON: {
274
- return get_store_common_to_tensor_fn_same_as_common <
275
- CTYPE_COMMON ,
279
+ return get_store_compute_to_tensor_fn_same_as_common <
280
+ CTYPE_COMPUTE ,
276
281
op_name>(t);
277
282
}
278
283
}
0 commit comments