Skip to content

Commit 0100a37

Browse files
pytorchbotGithub Executorch
and
Github Executorch
authored
Replace pytree_assert with production pytree_check. Remove pytree_unreachable
Pull Request resolved: #7655 When handling untrusted input, it's not appropriate to use debug-only checks; we should be checking in prod as these are not programmer errors. pytree_unreachable was similarly being used for input validation. ghstack-source-id: 265152270 Differential Revision: [D68166301](https://our.internmc.facebook.com/intern/diff/D68166301/) Co-authored-by: Github Executorch <[email protected]>
1 parent 391bd68 commit 0100a37

File tree

2 files changed

+44
-42
lines changed

2 files changed

+44
-42
lines changed

extension/pytree/pybindings.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,10 @@ class PyTree {
145145
} else if (py::isinstance<py::int_>(key)) {
146146
s.key(i) = py::cast<int32_t>(key);
147147
} else {
148-
pytree_assert(false);
148+
throw std::runtime_error(
149+
std::string(
150+
"invalid key in pytree dict; must be int or string but got ") +
151+
std::string(py::str(key.get_type())));
149152
}
150153

151154
flatten_internal(dict[key], leaves, s[i]);
@@ -175,7 +178,11 @@ class PyTree {
175178
break;
176179
}
177180
case Kind::None:
178-
pytree_assert(false);
181+
[[fallthrough]];
182+
default:
183+
throw std::runtime_error(
184+
std::string("invalid pytree kind ") + std::to_string(int(kind)) +
185+
" in flatten_internal");
179186
}
180187
}
181188

@@ -221,11 +228,12 @@ class PyTree {
221228
return py::cast(key.as_int()).release();
222229
case Key::Kind::Str:
223230
return py::cast(key.as_str()).release();
224-
case Key::Kind::None:
225-
pytree_assert(false);
231+
default:
232+
throw std::runtime_error(
233+
std::string("invalid key kind ") +
234+
std::to_string(int(key.kind())) +
235+
" in pytree dict; must be int or string");
226236
}
227-
pytree_assert(false);
228-
return py::none();
229237
}();
230238
dict[py_key] = unflatten_internal(spec[i], leaves_it);
231239
}
@@ -241,7 +249,9 @@ class PyTree {
241249
return py::none();
242250
}
243251
}
244-
pytree_assert(false);
252+
throw std::runtime_error(
253+
std::string("invalid spec kind ") + std::to_string(int(spec.kind())) +
254+
" in unflatten_internal");
245255
}
246256

247257
public:
@@ -339,12 +349,10 @@ static py::object broadcast_to_and_flatten(
339349
if (kind != top.x_spec_node->kind()) {
340350
return py::none();
341351
}
342-
pytree_assert(top.tree_spec_node->kind() == top.x_spec_node->kind());
343352
const size_t child_num = top.tree_spec_node->size();
344353
if (child_num != top.x_spec_node->size()) {
345354
return py::none();
346355
}
347-
pytree_assert(child_num == top.x_spec_node->size());
348356

349357
size_t x_leaves_offset =
350358
top.x_leaves_offset + top.x_spec_node->leaves_num();

extension/pytree/pytree.h

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ namespace executorch {
2525
namespace extension {
2626
namespace pytree {
2727

28-
inline void pytree_assert(bool must_be_true) {
29-
assert(must_be_true);
28+
inline void pytree_check(bool must_be_true) {
29+
if (!must_be_true) {
30+
throw std::runtime_error("pytree assertion failed");
31+
}
3032
}
3133

3234
#ifdef _MSC_VER
@@ -37,18 +39,6 @@ inline void pytree_assert(bool must_be_true) {
3739
#define EXECUTORCH_ALWAYS_INLINE inline
3840
#endif
3941

40-
[[noreturn]] EXECUTORCH_ALWAYS_INLINE void pytree_unreachable() {
41-
assert(false);
42-
#if defined(__GNUC__)
43-
__builtin_unreachable();
44-
#elif defined(_MSC_VER)
45-
__assume(0);
46-
#else
47-
while (!0)
48-
;
49-
#endif
50-
}
51-
5242
enum class Kind : uint8_t { List, Tuple, NamedTuple, Dict, Leaf, Custom, None };
5343

5444
using KeyStr = std::string;
@@ -144,45 +134,45 @@ struct ContainerHandle {
144134
: handle(std::move(c)) {}
145135

146136
void set_leaf(leaf_type* leaf) {
147-
pytree_assert(handle->kind == Kind::Leaf);
137+
pytree_check(handle->kind == Kind::Leaf);
148138
handle->leaf = leaf;
149139
}
150140

151141
operator leaf_type() const {
152-
pytree_assert(handle->kind == Kind::Leaf);
142+
pytree_check(handle->kind == Kind::Leaf);
153143
return *handle->leaf;
154144
}
155145

156146
const leaf_type& leaf() const {
157-
pytree_assert(handle->kind == Kind::Leaf);
147+
pytree_check(handle->kind == Kind::Leaf);
158148
return *handle->leaf;
159149
}
160150
leaf_type& leaf() {
161-
pytree_assert(handle->kind == Kind::Leaf);
151+
pytree_check(handle->kind == Kind::Leaf);
162152
return *handle->leaf;
163153
}
164154

165155
const leaf_type* leaf_ptr() const {
166-
pytree_assert(handle->kind == Kind::Leaf);
156+
pytree_check(handle->kind == Kind::Leaf);
167157
return handle->leaf;
168158
}
169159
leaf_type* leaf_ptr() {
170-
pytree_assert(handle->kind == Kind::Leaf);
160+
pytree_check(handle->kind == Kind::Leaf);
171161
return handle->leaf;
172162
}
173163

174164
const ContainerHandle& operator[](size_t idx) const {
175-
pytree_assert(idx < handle->size);
165+
pytree_check(idx < handle->size);
176166
return handle->items[idx];
177167
}
178168

179169
ContainerHandle& operator[](size_t idx) {
180-
pytree_assert(idx < handle->size);
170+
pytree_check(idx < handle->size);
181171
return handle->items[idx];
182172
}
183173

184174
bool contains(const KeyStr& lookup_key) const {
185-
pytree_assert(isDict());
175+
pytree_check(isDict());
186176
for (size_t i = 0; i < handle->size; ++i) {
187177
if (handle->keys[i] == lookup_key) {
188178
return true;
@@ -192,13 +182,13 @@ struct ContainerHandle {
192182
}
193183

194184
const ContainerHandle& at(const Key& lookup_key) const {
195-
pytree_assert(isDict());
185+
pytree_check(isDict());
196186
for (size_t i = 0; i < handle->size; ++i) {
197187
if (handle->keys[i] == lookup_key) {
198188
return handle->items[i];
199189
}
200190
}
201-
pytree_unreachable();
191+
throw std::runtime_error("Dict::at lookup failed");
202192
}
203193

204194
const ContainerHandle& at(const KeyInt& lookup_key) const {
@@ -210,11 +200,11 @@ struct ContainerHandle {
210200
}
211201

212202
const Key& key(size_t idx) const {
213-
pytree_assert(isDict());
203+
pytree_check(isDict());
214204
return handle->keys[idx];
215205
}
216206
Key& key(size_t idx) {
217-
pytree_assert(isDict());
207+
pytree_check(isDict());
218208
return handle->keys[idx];
219209
}
220210

@@ -399,7 +389,8 @@ StrTreeSpec to_str_internal(const TreeSpec<Aux>& spec) {
399389
s.append(key.as_str());
400390
s.push_back(Config::kDictStrKeyQuote);
401391
} else {
402-
pytree_unreachable();
392+
throw std::runtime_error(
393+
"invalid key in pytree dict; must be int or string");
403394
}
404395
s.push_back(Config::kDictKeyValueSep);
405396
s.append(to_str_internal(spec[i]));
@@ -475,6 +466,11 @@ struct arr {
475466

476467
inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
477468
size_t num = 0;
469+
if (!isdigit(spec.at(read_idx))) {
470+
throw std::runtime_error(
471+
std::string("expected a digit while decoding pytree, not ") +
472+
spec[read_idx]);
473+
}
478474
while (isdigit(spec.at(read_idx))) {
479475
num = 10 * num + (spec[read_idx] - '0');
480476
read_idx++;
@@ -583,7 +579,6 @@ TreeSpec<Aux> from_str_internal(
583579
c->keys[child_idx] = spec.substr(read_idx, key_len);
584580
read_idx = key_delim_idx + 2;
585581
} else {
586-
pytree_assert(isdigit(spec[read_idx]));
587582
size_t key = read_number(spec, read_idx);
588583
c->keys[child_idx] = KeyInt(key);
589584
read_idx += 1;
@@ -604,7 +599,6 @@ TreeSpec<Aux> from_str_internal(
604599
case Config::kLeaf:
605600
return new TreeSpecContainer<Aux>(nullptr);
606601
}
607-
pytree_unreachable();
608602
return new TreeSpecContainer<Aux>(Kind::None);
609603
}
610604

@@ -616,17 +610,17 @@ struct stack final {
616610
T data[SIZE];
617611

618612
void push(T&& item) {
619-
pytree_assert(size_ < SIZE);
613+
pytree_check(size_ < SIZE);
620614
data[size_++] = std::move(item);
621615
}
622616

623617
T pop() {
624-
pytree_assert(size_ > 0);
618+
pytree_check(size_ > 0);
625619
return data[--size_];
626620
}
627621

628622
T& top() {
629-
pytree_assert(size_ > 0);
623+
pytree_check(size_ > 0);
630624
return data[size_ - 1];
631625
}
632626

0 commit comments

Comments
 (0)