@@ -96,39 +96,43 @@ struct TensorMeta {
96
96
97
97
/* *
98
98
* Describes which dtype & dim order specialized kernel to be bound to an
99
- * operator. If `is_fallback_` is true, it means this kernel can be used as a
100
- * fallback, if false, it means this kernel can only be used if all the
101
- * `TensorMeta` are matched. Fallback means this kernel will be used for
102
- * all input tensor dtypes and dim orders, if the specialized kernel is not
103
- * registered.
99
+ * operator.
104
100
*
105
- * The format of a kernel key data is a string:
106
- * "v<version>/<tensor_meta>|<tensor_meta>..."
107
- * Size: Up to 691 1 1 1 (42 +1) * 16
108
- * Assuming max number of tensors is 16 ^
109
- * Kernel key version is v1 for now. If the kernel key format changes,
110
- * update the version to avoid breaking pre-existing kernel keys.
111
- * Example: v1/7;0,1,2,3
112
- * The kernel key has only one tensor: a double tensor with dimension 0, 1, 2, 3
101
+ * Kernel key data is a string with the format:
102
+ *
103
+ * "v<version>/<tensor_meta>|<tensor_meta>..."
104
+ *
105
+ * The version is v1 for now. If the kernel key format changes, update the
106
+ * version to avoid breaking pre-existing kernel keys.
113
107
*
114
108
* Each tensor_meta has the following format: "<dtype>;<dim_order,...>"
115
- * Size: Up to 42 1-2 1 24 (1 byte for 0-9; 2
116
- * for 10-15) + 15 commas Assuming that the max number of dims is 16 ^ Example:
117
- * 7;0,1,2,3 for [double; 0, 1, 2, 3]
109
+ *
110
+ * Example kernel key data: "v1/7;0,1,2,3|1;0,1,2,3,4,5,6,7"
111
+ *
112
+ * This has two tensors: the first with dtype=7 and dim order 0,1,2,3, and the
113
+ * second with dtype=1 and dim order 0,1,2,3,4,5,6,7.
118
114
*
119
115
* IMPORTANT:
120
116
* Users should not construct a kernel key manually. Instead, it should be
121
117
* generated from kernel yaml.
122
118
*/
123
119
struct KernelKey {
124
120
public:
121
+ /* *
122
+ * Creates a fallback (non-specialized) kernel key: this kernel can be used
123
+ * for all input tensor dtypes and dim orders if the specialized kernel is not
124
+ * registered.
125
+ */
125
126
KernelKey () : is_fallback_(true ) {}
126
127
128
+ /* *
129
+ * Creates a specialized (non-fallback) kernel key that matches a specific
130
+ * set of input tensor dtypes and dim orders. See the class comment for the
131
+ * expected format of `kernel_key_data`.
132
+ */
127
133
/* implicit */ KernelKey(const char * kernel_key_data)
128
134
: kernel_key_data_(kernel_key_data), is_fallback_(false ) {}
129
135
130
- constexpr static int MAX_SIZE = 691 ;
131
-
132
136
bool operator ==(const KernelKey& other) const {
133
137
return this ->equals (other);
134
138
}
@@ -144,7 +148,7 @@ struct KernelKey {
144
148
if (is_fallback_) {
145
149
return true ;
146
150
}
147
- return strncmp (kernel_key_data_, other.kernel_key_data_ , MAX_SIZE ) == 0 ;
151
+ return strcmp (kernel_key_data_, other.kernel_key_data_ ) == 0 ;
148
152
}
149
153
150
154
bool is_fallback () const {
@@ -194,7 +198,23 @@ struct Kernel {
194
198
};
195
199
196
200
namespace internal {
197
- void make_kernel_key_string (Span<const TensorMeta> key, char * buf);
201
+
202
+ /* *
203
+ * A make_kernel_key_string buffer size that is large enough to hold a kernel
204
+ * key string with 16 tensors of 16 dimensions, plus the trailing NUL byte.
205
+ */
206
+ constexpr size_t kKernelKeyBufSize = 659 ;
207
+
208
+ /* *
209
+ * Given the list of input tensor dtypes + dim orders, writes the kernel key
210
+ * string into the buffer. Returns an error if the buffer is too small or if the
211
+ * tensors cannot be represented as a valid key string.
212
+ */
213
+ Error make_kernel_key_string (
214
+ Span<const TensorMeta> key,
215
+ char * buf,
216
+ size_t buf_size);
217
+
198
218
} // namespace internal
199
219
200
220
/* *
0 commit comments