Skip to content

Replace pytree_assert with production pytree_check. Remove pytree_unreachable #8302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions extension/pytree/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ class PyTree {
} else if (py::isinstance<py::int_>(key)) {
s.key(i) = py::cast<int32_t>(key);
} else {
pytree_assert(false);
throw std::runtime_error(
std::string(
"invalid key in pytree dict; must be int or string but got ") +
std::string(py::str(key.get_type())));
}

flatten_internal(dict[key], leaves, s[i]);
Expand Down Expand Up @@ -175,7 +178,11 @@ class PyTree {
break;
}
case Kind::None:
pytree_assert(false);
[[fallthrough]];
default:
throw std::runtime_error(
std::string("invalid pytree kind ") + std::to_string(int(kind)) +
" in flatten_internal");
}
}

Expand Down Expand Up @@ -221,11 +228,12 @@ class PyTree {
return py::cast(key.as_int()).release();
case Key::Kind::Str:
return py::cast(key.as_str()).release();
case Key::Kind::None:
pytree_assert(false);
default:
throw std::runtime_error(
std::string("invalid key kind ") +
std::to_string(int(key.kind())) +
" in pytree dict; must be int or string");
}
pytree_assert(false);
return py::none();
}();
dict[py_key] = unflatten_internal(spec[i], leaves_it);
}
Expand All @@ -241,7 +249,9 @@ class PyTree {
return py::none();
}
}
pytree_assert(false);
throw std::runtime_error(
std::string("invalid spec kind ") + std::to_string(int(spec.kind())) +
" in unflatten_internal");
}

public:
Expand Down Expand Up @@ -339,12 +349,10 @@ static py::object broadcast_to_and_flatten(
if (kind != top.x_spec_node->kind()) {
return py::none();
}
pytree_assert(top.tree_spec_node->kind() == top.x_spec_node->kind());
const size_t child_num = top.tree_spec_node->size();
if (child_num != top.x_spec_node->size()) {
return py::none();
}
pytree_assert(child_num == top.x_spec_node->size());

size_t x_leaves_offset =
top.x_leaves_offset + top.x_spec_node->leaves_num();
Expand Down
60 changes: 27 additions & 33 deletions extension/pytree/pytree.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ namespace executorch {
namespace extension {
namespace pytree {

inline void pytree_assert(bool must_be_true) {
assert(must_be_true);
inline void pytree_check(bool must_be_true) {
if (!must_be_true) {
throw std::runtime_error("pytree assertion failed");
}
}

#ifdef _MSC_VER
Expand All @@ -37,18 +39,6 @@ inline void pytree_assert(bool must_be_true) {
#define EXECUTORCH_ALWAYS_INLINE inline
#endif

[[noreturn]] EXECUTORCH_ALWAYS_INLINE void pytree_unreachable() {
assert(false);
#if defined(__GNUC__)
__builtin_unreachable();
#elif defined(_MSC_VER)
__assume(0);
#else
while (!0)
;
#endif
}

enum class Kind : uint8_t { List, Tuple, NamedTuple, Dict, Leaf, Custom, None };

using KeyStr = std::string;
Expand Down Expand Up @@ -144,45 +134,45 @@ struct ContainerHandle {
: handle(std::move(c)) {}

void set_leaf(leaf_type* leaf) {
pytree_assert(handle->kind == Kind::Leaf);
pytree_check(handle->kind == Kind::Leaf);
handle->leaf = leaf;
}

operator leaf_type() const {
pytree_assert(handle->kind == Kind::Leaf);
pytree_check(handle->kind == Kind::Leaf);
return *handle->leaf;
}

const leaf_type& leaf() const {
pytree_assert(handle->kind == Kind::Leaf);
pytree_check(handle->kind == Kind::Leaf);
return *handle->leaf;
}
leaf_type& leaf() {
pytree_assert(handle->kind == Kind::Leaf);
pytree_check(handle->kind == Kind::Leaf);
return *handle->leaf;
}

const leaf_type* leaf_ptr() const {
pytree_assert(handle->kind == Kind::Leaf);
pytree_check(handle->kind == Kind::Leaf);
return handle->leaf;
}
leaf_type* leaf_ptr() {
pytree_assert(handle->kind == Kind::Leaf);
pytree_check(handle->kind == Kind::Leaf);
return handle->leaf;
}

const ContainerHandle& operator[](size_t idx) const {
pytree_assert(idx < handle->size);
pytree_check(idx < handle->size);
return handle->items[idx];
}

ContainerHandle& operator[](size_t idx) {
pytree_assert(idx < handle->size);
pytree_check(idx < handle->size);
return handle->items[idx];
}

bool contains(const KeyStr& lookup_key) const {
pytree_assert(isDict());
pytree_check(isDict());
for (size_t i = 0; i < handle->size; ++i) {
if (handle->keys[i] == lookup_key) {
return true;
Expand All @@ -192,13 +182,13 @@ struct ContainerHandle {
}

const ContainerHandle& at(const Key& lookup_key) const {
pytree_assert(isDict());
pytree_check(isDict());
for (size_t i = 0; i < handle->size; ++i) {
if (handle->keys[i] == lookup_key) {
return handle->items[i];
}
}
pytree_unreachable();
throw std::runtime_error("Dict::at lookup failed");
}

const ContainerHandle& at(const KeyInt& lookup_key) const {
Expand All @@ -210,11 +200,11 @@ struct ContainerHandle {
}

const Key& key(size_t idx) const {
pytree_assert(isDict());
pytree_check(isDict());
return handle->keys[idx];
}
Key& key(size_t idx) {
pytree_assert(isDict());
pytree_check(isDict());
return handle->keys[idx];
}

Expand Down Expand Up @@ -399,7 +389,8 @@ StrTreeSpec to_str_internal(const TreeSpec<Aux>& spec) {
s.append(key.as_str());
s.push_back(Config::kDictStrKeyQuote);
} else {
pytree_unreachable();
throw std::runtime_error(
"invalid key in pytree dict; must be int or string");
}
s.push_back(Config::kDictKeyValueSep);
s.append(to_str_internal(spec[i]));
Expand Down Expand Up @@ -475,6 +466,11 @@ struct arr {

inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
size_t num = 0;
if (!isdigit(spec.at(read_idx))) {
throw std::runtime_error(
std::string("expected a digit while decoding pytree, not ") +
spec[read_idx]);
}
while (isdigit(spec.at(read_idx))) {
num = 10 * num + (spec[read_idx] - '0');
read_idx++;
Expand Down Expand Up @@ -583,7 +579,6 @@ TreeSpec<Aux> from_str_internal(
c->keys[child_idx] = spec.substr(read_idx, key_len);
read_idx = key_delim_idx + 2;
} else {
pytree_assert(isdigit(spec[read_idx]));
size_t key = read_number(spec, read_idx);
c->keys[child_idx] = KeyInt(key);
read_idx += 1;
Expand All @@ -604,7 +599,6 @@ TreeSpec<Aux> from_str_internal(
case Config::kLeaf:
return new TreeSpecContainer<Aux>(nullptr);
}
pytree_unreachable();
return new TreeSpecContainer<Aux>(Kind::None);
}

Expand All @@ -616,17 +610,17 @@ struct stack final {
T data[SIZE];

void push(T&& item) {
pytree_assert(size_ < SIZE);
pytree_check(size_ < SIZE);
data[size_++] = std::move(item);
}

T pop() {
pytree_assert(size_ > 0);
pytree_check(size_ > 0);
return data[--size_];
}

T& top() {
pytree_assert(size_ > 0);
pytree_check(size_ > 0);
return data[size_ - 1];
}

Expand Down
Loading