Skip to content

add a bunch of bounds checking to pytree #8301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 51 additions & 14 deletions extension/pytree/pytree.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <cstdint>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <variant>

Expand Down Expand Up @@ -60,7 +61,7 @@ struct Key {
std::variant<std::monostate, KeyInt, KeyStr> repr_;

public:
Key() {}
Key() = default;
/*implicit*/ Key(KeyInt key) : repr_(key) {}
/*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {}

Expand Down Expand Up @@ -131,7 +132,7 @@ struct ContainerHandle {
using leaf_type = T;
std::unique_ptr<container_type> handle;

ContainerHandle() {}
ContainerHandle() = default;

template <typename... Args>
ContainerHandle(Args... args)
Expand Down Expand Up @@ -427,6 +428,22 @@ struct arr {
return data_[idx];
}

T& at(size_t idx) {
if (idx >= size()) {
throw std::out_of_range(
"bounds check failed in pytree arr at index " + std::to_string(idx));
}
return data_[idx];
}

const T& at(size_t idx) const {
if (idx >= size()) {
throw std::out_of_range(
"bounds check failed in pytree arr at index " + std::to_string(idx));
}
return data_[idx];
}

inline T* data() {
return data_.get();
}
Expand Down Expand Up @@ -458,7 +475,7 @@ struct arr {

inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
size_t num = 0;
while (isdigit(spec[read_idx])) {
while (isdigit(spec.at(read_idx))) {
num = 10 * num + (spec[read_idx] - '0');
read_idx++;
}
Expand All @@ -470,19 +487,22 @@ inline arr<size_t> read_node_layout(const StrTreeSpec& spec, size_t& read_idx) {
arr<size_t> ret(child_num);

size_t child_idx = 0;
while (spec[read_idx] == Config::kChildrenDataSep) {
while (spec.at(read_idx) == Config::kChildrenDataSep) {
++read_idx;
ret[child_idx++] = read_number(spec, read_idx);
ret.at(child_idx++) = read_number(spec, read_idx);
}
return ret;
}

// spec_data comes from pre_parse, which guarantees 1)
// spec_data.size() == spec.size() and 2) contents of spec_data are
// in-bounds indices for spec, so we omit bounds checks for spec_data.
template <typename Aux>
TreeSpec<Aux> from_str_internal(
const StrTreeSpec& spec,
size_t read_idx,
const arr<size_t>& spec_data) {
const auto kind_char = spec[read_idx];
const auto kind_char = spec.at(read_idx);
switch (kind_char) {
case Config::kTuple:
case Config::kNamedTuple:
Expand All @@ -496,7 +516,7 @@ TreeSpec<Aux> from_str_internal(
} else if (Config::kCustom == kind_char) {
kind = Kind::Custom;
read_idx++;
assert(spec[read_idx] == '(');
assert(spec.at(read_idx) == '(');
auto type_str_end = spec_data[read_idx];
read_idx++;
custom_type = spec.substr(read_idx, type_str_end - read_idx);
Expand All @@ -515,10 +535,15 @@ TreeSpec<Aux> from_str_internal(
size_t leaves_offset = 0;

if (size > 0) {
while (spec[read_idx] != Config::kNodeDataEnd) {
while (spec.at(read_idx) != Config::kNodeDataEnd) {
// NOLINTNEXTLINE
auto next_delim_idx = spec_data[read_idx];
read_idx++;
if (child_idx >= size) {
throw std::out_of_range(
"bounds check failed writing to pytree item at index " +
std::to_string(child_idx));
}
c->items[child_idx] =
from_str_internal<Aux>(spec, read_idx, spec_data);
read_idx = next_delim_idx;
Expand All @@ -541,11 +566,16 @@ TreeSpec<Aux> from_str_internal(
size_t leaves_offset = 0;

if (size > 0) {
while (spec[read_idx] != Config::kNodeDataEnd) {
while (spec.at(read_idx) != Config::kNodeDataEnd) {
// NOLINTNEXTLINE
auto next_delim_idx = spec_data[read_idx];
read_idx++;
if (spec[read_idx] == Config::kDictStrKeyQuote) {
if (child_idx >= size) {
throw std::out_of_range(
"bounds check failed decoding pytree dict at index " +
std::to_string(child_idx));
}
if (spec.at(read_idx) == Config::kDictStrKeyQuote) {
auto key_delim_idx = spec_data[read_idx];
read_idx++;
const size_t key_len = key_delim_idx - read_idx;
Expand All @@ -562,7 +592,7 @@ TreeSpec<Aux> from_str_internal(
c->items[child_idx] =
from_str_internal<Aux>(spec, read_idx, spec_data);
read_idx = next_delim_idx;
leaves_offset += layout[child_idx++];
leaves_offset += layout.at(child_idx++);
}
} else {
read_idx++;
Expand Down Expand Up @@ -605,7 +635,9 @@ struct stack final {
}
};

// We guarantee indicies in the result are in bounds.
inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
// Invariant: indices in stack are in bounds.
stack<std::pair<size_t, size_t>> stack;
size_t i = 0;
const size_t size = spec.size();
Expand All @@ -627,11 +659,16 @@ inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
case Config::kDictStrKeyQuote: {
size_t idx = i;
i++;
while (spec[i] != Config::kDictStrKeyQuote) {
while (spec.at(i) != Config::kDictStrKeyQuote) {
i++;
}
ret[idx] = i;
ret[i] = idx;
if (i >= size) {
throw std::out_of_range(
"bounds check failed while parsing dictionary key at index " +
std::to_string(i));
}
ret.at(idx) = i;
ret.at(i) = idx;
break;
}
case Config::kChildrenSep: {
Expand Down
20 changes: 20 additions & 0 deletions extension/pytree/test/test_pytree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using Leaf = int32_t;
TEST(PyTreeTest, ArrBasic) {
arr<int> x(5);
ASSERT_EQ(x.size(), 5);
EXPECT_THROW(x.at(5), std::out_of_range);
for (int ii = 0; ii < x.size(); ++ii) {
x[ii] = 2 * ii;
}
Expand Down Expand Up @@ -197,3 +198,22 @@ TEST(pytree, FlattenNestedDict) {
ASSERT_EQ(*leaves[i], items[i]);
}
}

TEST(pytree, EmptySpec) {
Leaf items[1] = {9};
EXPECT_THROW(unflatten("", items), std::out_of_range);
}

TEST(pytree, BoundsCheckListLayout) {
// Malformed: layout one child, have two
std::string spec = "L1#1($,$)";
Leaf items[2] = {11, 12};
EXPECT_THROW(unflatten(spec, items), std::out_of_range);
}

TEST(pytree, BoundsCheckDictLayout) {
// Malformed: layout one child, have two.
std::string spec = "D1#1('key0':$,'key1':$)";
Leaf items[2] = {11, 12};
EXPECT_THROW(unflatten(spec, items), std::out_of_range);
}
Loading