Skip to content

Commit 70143a2

Browse files
authored
Add safety checks when rendering kernel key strings
Differential Revision: D69324821 Pull Request resolved: #8327
1 parent 160e5b6 commit 70143a2

File tree

4 files changed

+384
-91
lines changed

4 files changed

+384
-91
lines changed

runtime/kernel/operator_registry.cpp

+90-24
Original file line numberDiff line numberDiff line change
@@ -114,44 +114,106 @@ Error register_kernels(const Span<const Kernel> kernels) {
114114
}
115115

116116
namespace {
117-
int copy_char_as_number_to_buf(char num, char* buf) {
118-
if ((char)num < 10) {
117+
/**
118+
* Writes `num` as a decimal string to `buf` and returns the number of bytes
119+
* written. Returns -1 if `buf` is too small or if `num` is not supported.
120+
*/
121+
int copy_char_as_number_to_buf(int num, char* buf, size_t buf_size) {
122+
if (num < 0) {
123+
return -1;
124+
}
125+
if (num < 10) {
126+
if (buf_size < 1) {
127+
return -1;
128+
}
119129
*buf = '0' + (char)num;
120-
buf += 1;
121130
return 1;
122-
} else {
123-
*buf = '0' + ((char)num) / 10;
124-
buf += 1;
131+
}
132+
if (num < 100) {
133+
if (buf_size < 2) {
134+
return -1;
135+
}
136+
*buf++ = '0' + ((char)num) / 10;
125137
*buf = '0' + ((char)num) % 10;
126-
buf += 1;
127138
return 2;
128139
}
140+
return -1;
129141
}
130142
} // namespace
131143

132144
namespace internal {
133-
void make_kernel_key_string(Span<const TensorMeta> key, char* buf) {
145+
Error make_kernel_key_string(
146+
Span<const TensorMeta> key,
147+
char* buf,
148+
size_t buf_size) {
134149
if (key.empty()) {
135-
// If no tensor is present in an op, kernel key does not apply
136-
return;
150+
// If no tensor is present in an op, kernel key does not apply.
151+
if (buf_size > 0) {
152+
buf[0] = '\0';
153+
}
154+
return Error::Ok;
137155
}
138-
strncpy(buf, "v1/", 3);
156+
157+
// Reserve one byte for null terminator.
158+
if (buf_size < 1) {
159+
return Error::InvalidArgument;
160+
}
161+
buf_size -= 1;
162+
163+
// Add prefix.
164+
if (buf_size < 3) {
165+
return Error::InvalidArgument;
166+
}
167+
memcpy(buf, "v1/", 3);
139168
buf += 3;
169+
buf_size -= 3;
170+
171+
// Add tensor meta.
140172
for (size_t i = 0; i < key.size(); i++) {
141173
auto& meta = key[i];
142-
buf += copy_char_as_number_to_buf((char)meta.dtype_, buf);
143-
*buf = ';';
144-
buf += 1;
174+
175+
// Add dtype.
176+
int n = copy_char_as_number_to_buf((int)meta.dtype_, buf, buf_size);
177+
if (n < 0) {
178+
return Error::InvalidArgument;
179+
}
180+
buf += n;
181+
buf_size -= n;
182+
183+
// Add separator between dtype and dim order.
184+
if (buf_size < 1) {
185+
return Error::InvalidArgument;
186+
}
187+
*buf++ = ';';
188+
buf_size -= 1;
189+
190+
// Add dim order.
145191
for (int j = 0; j < meta.dim_order_.size(); j++) {
146-
buf += copy_char_as_number_to_buf((char)meta.dim_order_[j], buf);
147-
if (j != meta.dim_order_.size() - 1) {
148-
*buf = ',';
149-
buf += 1;
192+
n = copy_char_as_number_to_buf((int)meta.dim_order_[j], buf, buf_size);
193+
if (n < 0) {
194+
return Error::InvalidArgument;
195+
}
196+
buf += n;
197+
buf_size -= n;
198+
199+
if (j < meta.dim_order_.size() - 1) {
200+
if (buf_size < 1) {
201+
return Error::InvalidArgument;
202+
}
203+
*buf++ = ',';
204+
buf_size -= 1;
205+
}
206+
}
207+
if (i < key.size() - 1) {
208+
if (buf_size < 1) {
209+
return Error::InvalidArgument;
150210
}
211+
*buf++ = '|';
212+
buf_size -= 1;
151213
}
152-
*buf = (i < (key.size() - 1)) ? '|' : 0x00;
153-
buf += 1;
154214
}
215+
*buf = '\0'; // Space for this was reserved above.
216+
return Error::Ok;
155217
}
156218
} // namespace internal
157219

@@ -164,10 +226,14 @@ bool registry_has_op_function(
164226
Result<OpFunction> get_op_function_from_registry(
165227
const char* name,
166228
Span<const TensorMeta> meta_list) {
167-
// @lint-ignore CLANGTIDY facebook-hte-CArray
168-
char buf[KernelKey::MAX_SIZE] = {0};
169-
internal::make_kernel_key_string(meta_list, buf);
170-
KernelKey kernel_key = KernelKey(buf);
229+
std::array<char, internal::kKernelKeyBufSize> key_string;
230+
Error err = internal::make_kernel_key_string(
231+
meta_list, key_string.data(), key_string.size());
232+
if (err != Error::Ok) {
233+
ET_LOG(Error, "Failed to make kernel key string");
234+
return err;
235+
}
236+
KernelKey kernel_key = KernelKey(key_string.data());
171237

172238
int32_t fallback_idx = -1;
173239
for (size_t idx = 0; idx < num_registered_kernels; idx++) {

runtime/kernel/operator_registry.h

+40-20
Original file line numberDiff line numberDiff line change
@@ -96,39 +96,43 @@ struct TensorMeta {
9696

9797
/**
9898
* Describes which dtype & dim order specialized kernel to be bound to an
99-
* operator. If `is_fallback_` is true, it means this kernel can be used as a
100-
* fallback, if false, it means this kernel can only be used if all the
101-
* `TensorMeta` are matched. Fallback means this kernel will be used for
102-
* all input tensor dtypes and dim orders, if the specialized kernel is not
103-
* registered.
99+
* operator.
104100
*
105-
* The format of a kernel key data is a string:
106-
* "v<version>/<tensor_meta>|<tensor_meta>..."
107-
* Size: Up to 691 1 1 1 (42 +1) * 16
108-
* Assuming max number of tensors is 16 ^
109-
* Kernel key version is v1 for now. If the kernel key format changes,
110-
* update the version to avoid breaking pre-existing kernel keys.
111-
* Example: v1/7;0,1,2,3
112-
* The kernel key has only one tensor: a double tensor with dimension 0, 1, 2, 3
101+
* Kernel key data is a string with the format:
102+
*
103+
* "v<version>/<tensor_meta>|<tensor_meta>..."
104+
*
105+
* The version is v1 for now. If the kernel key format changes, update the
106+
* version to avoid breaking pre-existing kernel keys.
113107
*
114108
* Each tensor_meta has the following format: "<dtype>;<dim_order,...>"
115-
* Size: Up to 42 1-2 1 24 (1 byte for 0-9; 2
116-
* for 10-15) + 15 commas Assuming that the max number of dims is 16 ^ Example:
117-
* 7;0,1,2,3 for [double; 0, 1, 2, 3]
109+
*
110+
* Example kernel key data: "v1/7;0,1,2,3|1;0,1,2,3,4,5,6,7"
111+
*
112+
* This has two tensors: the first with dtype=7 and dim order 0,1,2,3, and the
113+
* second with dtype=1 and dim order 0,1,2,3,4,5,6,7.
118114
*
119115
* IMPORTANT:
120116
* Users should not construct a kernel key manually. Instead, it should be
121117
* generated from kernel yaml.
122118
*/
123119
struct KernelKey {
124120
public:
121+
/**
122+
* Creates a fallback (non-specialized) kernel key: this kernel can be used
123+
* for all input tensor dtypes and dim orders if the specialized kernel is not
124+
* registered.
125+
*/
125126
KernelKey() : is_fallback_(true) {}
126127

128+
/**
129+
* Creates a specialized (non-fallback) kernel key that matches a specific
130+
* set of input tensor dtypes and dim orders. See the class comment for the
131+
* expected format of `kernel_key_data`.
132+
*/
127133
/* implicit */ KernelKey(const char* kernel_key_data)
128134
: kernel_key_data_(kernel_key_data), is_fallback_(false) {}
129135

130-
constexpr static int MAX_SIZE = 691;
131-
132136
bool operator==(const KernelKey& other) const {
133137
return this->equals(other);
134138
}
@@ -144,7 +148,7 @@ struct KernelKey {
144148
if (is_fallback_) {
145149
return true;
146150
}
147-
return strncmp(kernel_key_data_, other.kernel_key_data_, MAX_SIZE) == 0;
151+
return strcmp(kernel_key_data_, other.kernel_key_data_) == 0;
148152
}
149153

150154
bool is_fallback() const {
@@ -194,7 +198,23 @@ struct Kernel {
194198
};
195199

196200
namespace internal {
197-
void make_kernel_key_string(Span<const TensorMeta> key, char* buf);
201+
202+
/**
203+
* A make_kernel_key_string buffer size that is large enough to hold a kernel
204+
* key string with 16 tensors of 16 dimensions, plus the trailing NUL byte.
205+
*/
206+
constexpr size_t kKernelKeyBufSize = 659;
207+
208+
/**
209+
* Given the list of input tensor dtypes + dim orders, writes the kernel key
210+
* string into the buffer. Returns an error if the buffer is too small or if the
211+
* tensors cannot be represented as a valid key string.
212+
*/
213+
Error make_kernel_key_string(
214+
Span<const TensorMeta> key,
215+
char* buf,
216+
size_t buf_size);
217+
198218
} // namespace internal
199219

200220
/**

0 commit comments

Comments
 (0)