-
Notifications
You must be signed in to change notification settings - Fork 530
/
Copy pathoperator_registry.h
299 lines (262 loc) · 9.17 KB
/
operator_registry.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cstring>
#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/platform/compiler.h>
#include <executorch/runtime/platform/platform.h>
// Debug switch for operator registry
#if defined(ET_OP_REGISTRY_DEBUG)
#include <ostream>
#endif
/**
 * Logs a KernelKey's data string and whether it is a fallback key.
 *
 * A default-constructed (fallback) KernelKey has a null data() pointer, and
 * passing a null pointer to a "%s" conversion is undefined behavior, so null
 * data is logged as "(null)". The expansion already ends in a semicolon.
 */
#define ET_LOG_KERNEL_KEY(k)                         \
  ET_LOG(                                            \
      Info,                                          \
      "key: %s, is_fallback: %s",                    \
      (k).data() != nullptr ? (k).data() : "(null)", \
      (k).is_fallback() ? "true" : "false");
/**
 * Logs every TensorMeta in `meta_list`. For each entry it emits one ET_LOG
 * call for the dtype header, one ET_LOG call per dim-order element, and one
 * ET_LOG call for the closing bracket.
 *
 * Expands to a bare range-for statement, so no trailing semicolon is needed.
 */
#define ET_LOG_TENSOR_META(meta_list)                              \
  for (const auto& et_tensor_meta : meta_list) {                   \
    ET_LOG(                                                        \
        Info,                                                      \
        "dtype: %d | dim order: [",                                \
        static_cast<int>(et_tensor_meta.dtype_));                  \
    for (const auto et_dim : et_tensor_meta.dim_order_) {          \
      ET_LOG(Info, "%d,", static_cast<int32_t>(et_dim));           \
    }                                                              \
    ET_LOG(Info, "]");                                             \
  }
namespace executorch {
namespace ET_RUNTIME_NAMESPACE {
// Only referenced by reference below, so a forward declaration suffices.
class KernelRuntimeContext;
/// Signature of every kernel stored in a Kernel: it receives the runtime
/// context and a pointer to an array of EValue pointers (presumably the
/// operator's arguments — confirm against the executor's call site).
using OpFunction = void (*)(KernelRuntimeContext& context, EValue** args);
/**
 * Dtype and dim order metadata for a Tensor argument to an operator.
 * Used by the Executor to hold the tensor metadata info and retrieve kernel.
 *
 * dim_order_ is held as a Span (presumably a non-owning view — the caller
 * keeps the underlying dim order data alive; confirm).
 */
struct TensorMeta {
  executorch::aten::ScalarType dtype_;
  Span<executorch::aten::DimOrderType> dim_order_;

  // NOTE: the defaulted constructor leaves dtype_ uninitialized.
  TensorMeta() = default;

  TensorMeta(
      executorch::aten::ScalarType dtype,
      Span<executorch::aten::DimOrderType> order)
      : dtype_(dtype), dim_order_(order) {}

  bool operator==(const TensorMeta& other) const {
    return this->equals(other);
  }

  bool operator!=(const TensorMeta& other) const {
    return !this->equals(other);
  }

  /**
   * Returns true when the dtypes match and both dim orders have the same
   * length and identical elements.
   */
  bool equals(const TensorMeta& other) const {
    if (dtype_ != other.dtype_) {
      return false;
    }
    if (dim_order_.size() != other.dim_order_.size()) {
      return false;
    }
    for (size_t i = 0; i < dim_order_.size(); i++) {
      if (dim_order_[i] != other.dim_order_[i]) {
        return false;
      }
    }
    return true;
  }

#if defined(ET_OP_REGISTRY_DEBUG)
  friend std::ostream& operator<<(std::ostream& os, const TensorMeta& meta) {
    os << "dtype: " << int(meta.dtype_) << " | dim order: [";
    // size_t index, matching equals(): comparing an int index against the
    // size_t returned by size() triggers -Wsign-compare.
    for (size_t i = 0; i < meta.dim_order_.size(); i++) {
      os << static_cast<int32_t>(meta.dim_order_[i]) << ", ";
    }
    os << "]";
    return os;
  }
#endif
};
/**
 * Describes which dtype & dim order specialized kernel to be bound to an
 * operator.
 *
 * Kernel key data is a string with the format:
 *
 *   "v<version>/<tensor_meta>|<tensor_meta>..."
 *
 * The version is v1 for now. If the kernel key format changes, update the
 * version to avoid breaking pre-existing kernel keys.
 *
 * Each tensor_meta has the following format: "<dtype>;<dim_order,...>"
 *
 * Example kernel key data: "v1/7;0,1,2,3|1;0,1,2,3,4,5,6,7"
 *
 * This has two tensors: the first with dtype=7 and dim order 0,1,2,3, and the
 * second with dtype=1 and dim order 0,1,2,3,4,5,6,7.
 *
 * IMPORTANT:
 * Users should not construct a kernel key manually. Instead, it should be
 * generated from kernel yaml.
 */
struct KernelKey {
 public:
  /**
   * Creates a fallback (non-specialized) kernel key: this kernel can be used
   * for all input tensor dtypes and dim orders if the specialized kernel is not
   * registered. data() returns nullptr for such a key.
   */
  KernelKey() : is_fallback_(true) {}

  /**
   * Creates a specialized (non-fallback) kernel key that matches a specific
   * set of input tensor dtypes and dim orders. See the class comment for the
   * expected format of `kernel_key_data`. Only the pointer is stored; the
   * string is not copied.
   */
  /* implicit */ KernelKey(const char* kernel_key_data)
      : kernel_key_data_(kernel_key_data), is_fallback_(false) {}

  bool operator==(const KernelKey& other) const {
    return this->equals(other);
  }

  bool operator!=(const KernelKey& other) const {
    return !this->equals(other);
  }

  /**
   * Fallback keys are equal to each other and never equal to a specialized
   * key. Specialized keys are equal when their key strings compare equal; a
   * specialized key with null data equals only another null-data key.
   */
  bool equals(const KernelKey& other) const {
    if (is_fallback_ != other.is_fallback_) {
      return false;
    }
    if (is_fallback_) {
      return true;
    }
    const char* lhs = kernel_key_data_;
    const char* rhs = other.kernel_key_data_;
    if (lhs == nullptr || rhs == nullptr) {
      // strcmp() on a null pointer is undefined behavior.
      return lhs == rhs;
    }
    return strcmp(lhs, rhs) == 0;
  }

  bool is_fallback() const {
    return is_fallback_;
  }

  const char* data() const {
    return kernel_key_data_;
  }

#if defined(ET_OP_REGISTRY_DEBUG)
  friend std::ostream& operator<<(std::ostream& os, const KernelKey& key) {
    // Streaming a null const char* is undefined behavior, and fallback keys
    // carry null data.
    os << (key.kernel_key_data_ != nullptr ? key.kernel_key_data_ : "(null)")
       << '\n';
    return os;
  }
#endif

 private:
  // Key string; not copied, so presumably it must outlive this key.
  const char* kernel_key_data_ = nullptr;
  bool is_fallback_;
};
/**
 * Struct that bundles a kernel key, a function and an op name together. An
 * `Operator` may have more than one `Kernel` (maximum kMaxNumOfKernelPerOp) and
 * they should have the same op name and different kernel key. A "fallback"
 * kernel may or may not live in an `Operator`.
 */
struct Kernel {
  // Operator name. Only the pointer is stored (see the constructor comment).
  const char* name_ = nullptr;
  // Dtype/dim-order specialization of this kernel; a default-constructed
  // KernelKey marks a fallback kernel. The key's string data is not owned by
  // the Kernel struct.
  KernelKey kernel_key_;
  // Function invoked to run this kernel.
  OpFunction op_ = nullptr;

  /**
   * Only the pointer to the operator name is copied, not the string itself,
   * so the name must live at least as long as the operator registry.
   *
   * This overload leaves kernel_key_ default-constructed, i.e. it describes a
   * fallback kernel.
   */
  explicit Kernel(const char* name, OpFunction func) : name_(name), op_(func) {}

  explicit Kernel(const char* name, KernelKey key, OpFunction func)
      : name_(name), kernel_key_(key), op_(func) {}

  // Empty kernel: null name_ and op_, fallback kernel_key_.
  Kernel() = default;
};
namespace internal {
/**
 * Buffer size for make_kernel_key_string() that can hold a kernel key string
 * describing 16 tensors of 16 dimensions each, including the trailing NUL.
 *
 * The value is 659. The breakdown below reproduces it. It assumes the key
 * follows the "v1/<dtype>;<d0>,...|..." format documented on KernelKey and
 * that every dtype code fits in two decimal digits (confirm against
 * ScalarType):
 *   "v1/" prefix                                         3
 *   16 tensors x (2 dtype + 1 ';' + 22 digits for dims
 *                 0..15 + 15 ',')               16 x 40 = 640
 *   '|' separators between the 16 tensors               15
 *   trailing NUL                                          1
 */
constexpr size_t kKernelKeyBufSize = 3 + 16 * (2 + 1 + 22 + 15) + 15 + 1;
static_assert(kKernelKeyBufSize == 659, "kernel key buffer size changed");

/**
 * Renders the kernel key string for a list of input tensor dtypes and dim
 * orders.
 *
 * @param[in] key Dtype and dim order of each input tensor.
 * @param[out] buf Destination for the key string; kKernelKeyBufSize bytes are
 *     enough for up to 16 tensors of 16 dimensions.
 * @param[in] buf_size Capacity of `buf` in bytes.
 * @returns An error if `buf` is too small or if the tensors cannot be
 *     represented as a valid key string.
 */
Error make_kernel_key_string(
    Span<const TensorMeta> key,
    char* buf,
    size_t buf_size);
} // namespace internal
/**
 * Reports whether the registry holds a kernel for operator `name`.
 *
 * @param[in] name Operator name to look up.
 * @param[in] meta_list Dtype and dim order of each input tensor. An empty list
 *     means the op has no specialized kernels, so the check is whether any
 *     fallback kernel is registered.
 * @returns true if a matching kernel exists.
 */
auto registry_has_op_function(
    const char* name,
    Span<const TensorMeta> meta_list = {}) -> bool;
/**
* Returns the operator with a given name and TensorMeta list, if present.
*/
::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
const char* name,
Span<const TensorMeta> meta_list = {});
/**
* Returns all registered kernels.
*/
Span<const Kernel> get_registered_kernels();
/**
* Registers the provided kernels.
*
* @param[in] kernels Kernel objects to register.
* @retval Error::Ok always. Panics on error. This function needs to return a
* non-void type to run at static initialization time.
*/
ET_NODISCARD Error register_kernels(const Span<const Kernel>);
/**
 * Registers a single kernel by forwarding it to register_kernels() as a
 * one-element span.
 *
 * @param[in] kernel Kernel object to register.
 * @retval Error::Ok always. Panics on error. This function needs to return a
 * non-void type to run at static initialization time.
 */
ET_NODISCARD inline Error register_kernel(const Kernel& kernel) {
  return register_kernels({&kernel, 1});
}
} // namespace ET_RUNTIME_NAMESPACE
} // namespace executorch
namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::ET_RUNTIME_NAMESPACE::Kernel;
using ::executorch::ET_RUNTIME_NAMESPACE::KernelKey;
using ::executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext;
using ::executorch::ET_RUNTIME_NAMESPACE::OpFunction;
using ::executorch::ET_RUNTIME_NAMESPACE::TensorMeta;

/// Deprecated wrapper: forwards `kernels` to
/// ::executorch::ET_RUNTIME_NAMESPACE::register_kernels().
inline ::executorch::runtime::Error register_kernels(ArrayRef<Kernel> kernels) {
  return ::executorch::ET_RUNTIME_NAMESPACE::register_kernels(
      {kernels.data(), kernels.size()});
}

/// Deprecated wrapper: returns the kernel registered for `name` and
/// `meta_list`. Fails an ET_CHECK when the lookup does not succeed.
inline OpFunction getOpsFn(
    const char* name,
    ArrayRef<TensorMeta> meta_list = {}) {
  auto result =
      ::executorch::ET_RUNTIME_NAMESPACE::get_op_function_from_registry(
          name, {meta_list.data(), meta_list.size()});
  ET_CHECK(result.ok()); // get_op_function_from_registry() logs details.
  return *result;
}

/// Deprecated wrapper: forwards to
/// ::executorch::ET_RUNTIME_NAMESPACE::registry_has_op_function().
inline bool hasOpsFn(const char* name, ArrayRef<TensorMeta> meta_list = {}) {
  return ::executorch::ET_RUNTIME_NAMESPACE::registry_has_op_function(
      name, {meta_list.data(), meta_list.size()});
}

/// Deprecated wrapper: returns all registered kernels as an ArrayRef.
inline ArrayRef<Kernel> get_kernels() {
  Span<const Kernel> kernels =
      ::executorch::ET_RUNTIME_NAMESPACE::get_registered_kernels();
  return ArrayRef<Kernel>(kernels.data(), kernels.size());
}
} // namespace executor
} // namespace torch