Skip to content

Commit 655bf86

Browse files
committed
Replace pytree_assert with production pytree_check. Remove pytree_unreachable
Pull Request resolved: #7655 When handling untrusted input, it's not appropriate to use debug-only checks; we should be checking in prod as these are not programmer errors. pytree_unreachable was similarly being used for input validation. ghstack-source-id: 262806899 Differential Revision: [D68166301](https://our.internmc.facebook.com/intern/diff/D68166301/)
1 parent 81b8ced commit 655bf86

File tree

2 files changed

+44
-42
lines changed

2 files changed

+44
-42
lines changed

extension/pytree/pybindings.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,10 @@ class PyTree {
145145
} else if (py::isinstance<py::int_>(key)) {
146146
s.key(i) = py::cast<int32_t>(key);
147147
} else {
148-
pytree_assert(false);
148+
throw std::runtime_error(
149+
std::string(
150+
"invalid key in pytree dict; must be int or string but got ") +
151+
std::string(py::str(key.get_type())));
149152
}
150153

151154
flatten_internal(dict[key], leaves, s[i]);
@@ -175,7 +178,11 @@ class PyTree {
175178
break;
176179
}
177180
case Kind::None:
178-
pytree_assert(false);
181+
[[fallthrough]];
182+
default:
183+
throw std::runtime_error(
184+
std::string("invalid pytree kind ") + std::to_string(int(kind)) +
185+
" in flatten_internal");
179186
}
180187
}
181188

@@ -221,11 +228,12 @@ class PyTree {
221228
return py::cast(key.as_int()).release();
222229
case Key::Kind::Str:
223230
return py::cast(key.as_str()).release();
224-
case Key::Kind::None:
225-
pytree_assert(false);
231+
default:
232+
throw std::runtime_error(
233+
std::string("invalid key kind ") +
234+
std::to_string(int(key.kind())) +
235+
" in pytree dict; must be int or string");
226236
}
227-
pytree_assert(false);
228-
return py::none();
229237
}();
230238
dict[py_key] = unflatten_internal(spec[i], leaves_it);
231239
}
@@ -241,7 +249,9 @@ class PyTree {
241249
return py::none();
242250
}
243251
}
244-
pytree_assert(false);
252+
throw std::runtime_error(
253+
std::string("invalid spec kind ") + std::to_string(int(spec.kind())) +
254+
" in unflatten_internal");
245255
}
246256

247257
public:
@@ -339,12 +349,10 @@ static py::object broadcast_to_and_flatten(
339349
if (kind != top.x_spec_node->kind()) {
340350
return py::none();
341351
}
342-
pytree_assert(top.tree_spec_node->kind() == top.x_spec_node->kind());
343352
const size_t child_num = top.tree_spec_node->size();
344353
if (child_num != top.x_spec_node->size()) {
345354
return py::none();
346355
}
347-
pytree_assert(child_num == top.x_spec_node->size());
348356

349357
size_t x_leaves_offset =
350358
top.x_leaves_offset + top.x_spec_node->leaves_num();

extension/pytree/pytree.h

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ namespace executorch {
2424
namespace extension {
2525
namespace pytree {
2626

27-
inline void pytree_assert(bool must_be_true) {
28-
assert(must_be_true);
27+
inline void pytree_check(bool must_be_true) {
28+
if (!must_be_true) {
29+
throw std::runtime_error("pytree assertion failed");
30+
}
2931
}
3032

3133
#ifdef _MSC_VER
@@ -36,18 +38,6 @@ inline void pytree_assert(bool must_be_true) {
3638
#define EXECUTORCH_ALWAYS_INLINE inline
3739
#endif
3840

39-
[[noreturn]] EXECUTORCH_ALWAYS_INLINE void pytree_unreachable() {
40-
assert(false);
41-
#if defined(__GNUC__)
42-
__builtin_unreachable();
43-
#elif defined(_MSC_VER)
44-
__assume(0);
45-
#else
46-
while (!0)
47-
;
48-
#endif
49-
}
50-
5141
enum class Kind : uint8_t { List, Tuple, NamedTuple, Dict, Leaf, Custom, None };
5242

5343
using KeyStr = std::string;
@@ -143,45 +133,45 @@ struct ContainerHandle {
143133
: handle(std::move(c)) {}
144134

145135
void set_leaf(leaf_type* leaf) {
146-
pytree_assert(handle->kind == Kind::Leaf);
136+
pytree_check(handle->kind == Kind::Leaf);
147137
handle->leaf = leaf;
148138
}
149139

150140
operator leaf_type() const {
151-
pytree_assert(handle->kind == Kind::Leaf);
141+
pytree_check(handle->kind == Kind::Leaf);
152142
return *handle->leaf;
153143
}
154144

155145
const leaf_type& leaf() const {
156-
pytree_assert(handle->kind == Kind::Leaf);
146+
pytree_check(handle->kind == Kind::Leaf);
157147
return *handle->leaf;
158148
}
159149
leaf_type& leaf() {
160-
pytree_assert(handle->kind == Kind::Leaf);
150+
pytree_check(handle->kind == Kind::Leaf);
161151
return *handle->leaf;
162152
}
163153

164154
const leaf_type* leaf_ptr() const {
165-
pytree_assert(handle->kind == Kind::Leaf);
155+
pytree_check(handle->kind == Kind::Leaf);
166156
return handle->leaf;
167157
}
168158
leaf_type* leaf_ptr() {
169-
pytree_assert(handle->kind == Kind::Leaf);
159+
pytree_check(handle->kind == Kind::Leaf);
170160
return handle->leaf;
171161
}
172162

173163
const ContainerHandle& operator[](size_t idx) const {
174-
pytree_assert(idx < handle->size);
164+
pytree_check(idx < handle->size);
175165
return handle->items[idx];
176166
}
177167

178168
ContainerHandle& operator[](size_t idx) {
179-
pytree_assert(idx < handle->size);
169+
pytree_check(idx < handle->size);
180170
return handle->items[idx];
181171
}
182172

183173
bool contains(const KeyStr& lookup_key) const {
184-
pytree_assert(isDict());
174+
pytree_check(isDict());
185175
for (size_t i = 0; i < handle->size; ++i) {
186176
if (handle->keys[i] == lookup_key) {
187177
return true;
@@ -191,13 +181,13 @@ struct ContainerHandle {
191181
}
192182

193183
const ContainerHandle& at(const Key& lookup_key) const {
194-
pytree_assert(isDict());
184+
pytree_check(isDict());
195185
for (size_t i = 0; i < handle->size; ++i) {
196186
if (handle->keys[i] == lookup_key) {
197187
return handle->items[i];
198188
}
199189
}
200-
pytree_unreachable();
190+
throw std::runtime_error("Dict::at lookup failed");
201191
}
202192

203193
const ContainerHandle& at(const KeyInt& lookup_key) const {
@@ -209,11 +199,11 @@ struct ContainerHandle {
209199
}
210200

211201
const Key& key(size_t idx) const {
212-
pytree_assert(isDict());
202+
pytree_check(isDict());
213203
return handle->keys[idx];
214204
}
215205
Key& key(size_t idx) {
216-
pytree_assert(isDict());
206+
pytree_check(isDict());
217207
return handle->keys[idx];
218208
}
219209

@@ -398,7 +388,8 @@ StrTreeSpec to_str_internal(const TreeSpec<Aux>& spec) {
398388
s.append(key.as_str());
399389
s.push_back(Config::kDictStrKeyQuote);
400390
} else {
401-
pytree_unreachable();
391+
throw std::runtime_error(
392+
"invalid key in pytree dict; must be int or string");
402393
}
403394
s.push_back(Config::kDictKeyValueSep);
404395
s.append(to_str_internal(spec[i]));
@@ -474,6 +465,11 @@ struct arr {
474465

475466
inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
476467
size_t num = 0;
468+
if (!isdigit(spec.at(read_idx))) {
469+
throw std::runtime_error(
470+
std::string("expected a digit while decoding pytree, not ") +
471+
spec[read_idx]);
472+
}
477473
while (isdigit(spec.at(read_idx))) {
478474
num = 10 * num + (spec[read_idx] - '0');
479475
read_idx++;
@@ -582,7 +578,6 @@ TreeSpec<Aux> from_str_internal(
582578
c->keys[child_idx] = spec.substr(read_idx, key_len);
583579
read_idx = key_delim_idx + 2;
584580
} else {
585-
pytree_assert(isdigit(spec[read_idx]));
586581
size_t key = read_number(spec, read_idx);
587582
c->keys[child_idx] = KeyInt(key);
588583
read_idx += 1;
@@ -603,7 +598,6 @@ TreeSpec<Aux> from_str_internal(
603598
case Config::kLeaf:
604599
return new TreeSpecContainer<Aux>(nullptr);
605600
}
606-
pytree_unreachable();
607601
return new TreeSpecContainer<Aux>(Kind::None);
608602
}
609603

@@ -615,17 +609,17 @@ struct stack final {
615609
T data[SIZE];
616610

617611
void push(T&& item) {
618-
pytree_assert(size_ < SIZE);
612+
pytree_check(size_ < SIZE);
619613
data[size_++] = std::move(item);
620614
}
621615

622616
T pop() {
623-
pytree_assert(size_ > 0);
617+
pytree_check(size_ > 0);
624618
return data[--size_];
625619
}
626620

627621
T& top() {
628-
pytree_assert(size_ > 0);
622+
pytree_check(size_ > 0);
629623
return data[size_ - 1];
630624
}
631625

0 commit comments

Comments
 (0)