Skip to content

Commit 12aa444

Browse files
committed
Replace pytree_assert with production pytree_check. Remove pytree_unreachable
Pull Request resolved: #7655 When handling untrusted input, it's not appropriate to use debug-only checks; we should be checking in prod as these are not programmer errors. pytree_unreachable was similarly being used for input validation. ghstack-source-id: 262028959 Differential Revision: [D68166301](https://our.internmc.facebook.com/intern/diff/D68166301/)
1 parent 73ad5a6 commit 12aa444

File tree

2 files changed

+34
-42
lines changed

2 files changed

+34
-42
lines changed

extension/pytree/pybindings.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ class PyTree {
145145
} else if (py::isinstance<py::int_>(key)) {
146146
s.key(i) = py::cast<int32_t>(key);
147147
} else {
148-
pytree_assert(false);
148+
throw std::runtime_err(
149+
"invalid key in pytree dict; must be int or string");
149150
}
150151

151152
flatten_internal(dict[key], leaves, s[i]);
@@ -175,7 +176,9 @@ class PyTree {
175176
break;
176177
}
177178
case Kind::None:
178-
pytree_assert(false);
179+
[[fallthrough]];
180+
default:
181+
throw std::runtime_error("invalid pytree kind in flatten_internal");
179182
}
180183
}
181184

@@ -221,11 +224,10 @@ class PyTree {
221224
return py::cast(key.as_int()).release();
222225
case Key::Kind::Str:
223226
return py::cast(key.as_str()).release();
224-
case Key::Kind::None:
225-
pytree_assert(false);
227+
default:
228+
throw std::runtime_error(
229+
"invalid key in pytree dict; must be int or string");
226230
}
227-
pytree_assert(false);
228-
return py::none();
229231
}();
230232
dict[py_key] = unflatten_internal(spec[i], leaves_it);
231233
}
@@ -241,7 +243,7 @@ class PyTree {
241243
return py::none();
242244
}
243245
}
244-
pytree_assert(false);
246+
throw std::runtime_error("invalid spec kind in unflatten_internal");
245247
}
246248

247249
public:
@@ -339,12 +341,10 @@ static py::object broadcast_to_and_flatten(
339341
if (kind != top.x_spec_node->kind()) {
340342
return py::none();
341343
}
342-
pytree_assert(top.tree_spec_node->kind() == top.x_spec_node->kind());
343344
const size_t child_num = top.tree_spec_node->size();
344345
if (child_num != top.x_spec_node->size()) {
345346
return py::none();
346347
}
347-
pytree_assert(child_num == top.x_spec_node->size());
348348

349349
size_t x_leaves_offset =
350350
top.x_leaves_offset + top.x_spec_node->leaves_num();

extension/pytree/pytree.h

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ namespace executorch {
2424
namespace extension {
2525
namespace pytree {
2626

27-
inline void pytree_assert(bool must_be_true) {
28-
assert(must_be_true);
27+
inline void pytree_check(bool must_be_true) {
28+
if (!must_be_true) {
29+
throw std::runtime_error("pytree assertion failed");
30+
}
2931
}
3032

3133
#ifdef _MSC_VER
@@ -36,18 +38,6 @@ inline void pytree_assert(bool must_be_true) {
3638
#define EXECUTORCH_ALWAYS_INLINE inline
3739
#endif
3840

39-
[[noreturn]] EXECUTORCH_ALWAYS_INLINE void pytree_unreachable() {
40-
assert(false);
41-
#if defined(__GNUC__)
42-
__builtin_unreachable();
43-
#elif defined(_MSC_VER)
44-
__assume(0);
45-
#else
46-
while (!0)
47-
;
48-
#endif
49-
}
50-
5141
enum class Kind : uint8_t { List, Tuple, NamedTuple, Dict, Leaf, Custom, None };
5242

5343
using KeyStr = std::string;
@@ -143,45 +133,45 @@ struct ContainerHandle {
143133
: handle(std::move(c)) {}
144134

145135
void set_leaf(leaf_type* leaf) {
146-
pytree_assert(handle->kind == Kind::Leaf);
136+
pytree_check(handle->kind == Kind::Leaf);
147137
handle->leaf = leaf;
148138
}
149139

150140
operator leaf_type() const {
151-
pytree_assert(handle->kind == Kind::Leaf);
141+
pytree_check(handle->kind == Kind::Leaf);
152142
return *handle->leaf;
153143
}
154144

155145
const leaf_type& leaf() const {
156-
pytree_assert(handle->kind == Kind::Leaf);
146+
pytree_check(handle->kind == Kind::Leaf);
157147
return *handle->leaf;
158148
}
159149
leaf_type& leaf() {
160-
pytree_assert(handle->kind == Kind::Leaf);
150+
pytree_check(handle->kind == Kind::Leaf);
161151
return *handle->leaf;
162152
}
163153

164154
const leaf_type* leaf_ptr() const {
165-
pytree_assert(handle->kind == Kind::Leaf);
155+
pytree_check(handle->kind == Kind::Leaf);
166156
return handle->leaf;
167157
}
168158
leaf_type* leaf_ptr() {
169-
pytree_assert(handle->kind == Kind::Leaf);
159+
pytree_check(handle->kind == Kind::Leaf);
170160
return handle->leaf;
171161
}
172162

173163
const ContainerHandle& operator[](size_t idx) const {
174-
pytree_assert(idx < handle->size);
164+
pytree_check(idx < handle->size);
175165
return handle->items[idx];
176166
}
177167

178168
ContainerHandle& operator[](size_t idx) {
179-
pytree_assert(idx < handle->size);
169+
pytree_check(idx < handle->size);
180170
return handle->items[idx];
181171
}
182172

183173
bool contains(const KeyStr& lookup_key) const {
184-
pytree_assert(isDict());
174+
pytree_check(isDict());
185175
for (size_t i = 0; i < handle->size; ++i) {
186176
if (handle->keys[i] == lookup_key) {
187177
return true;
@@ -191,13 +181,13 @@ struct ContainerHandle {
191181
}
192182

193183
const ContainerHandle& at(const Key& lookup_key) const {
194-
pytree_assert(isDict());
184+
pytree_check(isDict());
195185
for (size_t i = 0; i < handle->size; ++i) {
196186
if (handle->keys[i] == lookup_key) {
197187
return handle->items[i];
198188
}
199189
}
200-
pytree_unreachable();
190+
throw std::runtime_error("Dict::at lookup failed");
201191
}
202192

203193
const ContainerHandle& at(const KeyInt& lookup_key) const {
@@ -209,11 +199,11 @@ struct ContainerHandle {
209199
}
210200

211201
const Key& key(size_t idx) const {
212-
pytree_assert(isDict());
202+
pytree_check(isDict());
213203
return handle->keys[idx];
214204
}
215205
Key& key(size_t idx) {
216-
pytree_assert(isDict());
206+
pytree_check(isDict());
217207
return handle->keys[idx];
218208
}
219209

@@ -398,7 +388,8 @@ StrTreeSpec to_str_internal(const TreeSpec<Aux>& spec) {
398388
s.append(key.as_str());
399389
s.push_back(Config::kDictStrKeyQuote);
400390
} else {
401-
pytree_unreachable();
391+
throw std::runtime_error(
392+
"invalid key in pytree dict; must be int or string");
402393
}
403394
s.push_back(Config::kDictKeyValueSep);
404395
s.append(to_str_internal(spec[i]));
@@ -472,6 +463,9 @@ struct arr {
472463

473464
inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
474465
size_t num = 0;
466+
if (!isdigit(spec.at(read_idx))) {
467+
throw std::runtime_error("expected a number while decoding pytree");
468+
}
475469
while (isdigit(spec.at(read_idx))) {
476470
num = 10 * num + (spec[read_idx] - '0');
477471
read_idx++;
@@ -577,7 +571,6 @@ TreeSpec<Aux> from_str_internal(
577571
c->keys[child_idx] = spec.substr(read_idx, key_len);
578572
read_idx = key_delim_idx + 2;
579573
} else {
580-
pytree_assert(isdigit(spec[read_idx]));
581574
size_t key = read_number(spec, read_idx);
582575
c->keys[child_idx] = KeyInt(key);
583576
read_idx += 1;
@@ -598,7 +591,6 @@ TreeSpec<Aux> from_str_internal(
598591
case Config::kLeaf:
599592
return new TreeSpecContainer<Aux>(nullptr);
600593
}
601-
pytree_unreachable();
602594
return new TreeSpecContainer<Aux>(Kind::None);
603595
}
604596

@@ -610,17 +602,17 @@ struct stack final {
610602
T data[SIZE];
611603

612604
void push(T&& item) {
613-
pytree_assert(size_ < SIZE);
605+
pytree_check(size_ < SIZE);
614606
data[size_++] = std::move(item);
615607
}
616608

617609
T pop() {
618-
pytree_assert(size_ > 0);
610+
pytree_check(size_ > 0);
619611
return data[--size_];
620612
}
621613

622614
T& top() {
623-
pytree_assert(size_ > 0);
615+
pytree_check(size_ > 0);
624616
return data[size_ - 1];
625617
}
626618

0 commit comments

Comments
 (0)