Skip to content

Commit 81b8ced

Browse files
committed
add a bunch of bounds checking to pytree
Pull Request resolved: #7654 It's possible to pass arbitrary string input to pytree from Python; let's not have a bunch of low-hanging memory safety issues. ghstack-source-id: 262806900 @exported-using-ghexport Differential Revision: [D68166303](https://our.internmc.facebook.com/intern/diff/D68166303/)
1 parent 67bf46a commit 81b8ced

File tree

2 files changed

+70
-14
lines changed

2 files changed

+70
-14
lines changed

extension/pytree/pytree.h

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct Key {
6060
std::variant<std::monostate, KeyInt, KeyStr> repr_;
6161

6262
public:
63-
Key() {}
63+
Key() = default;
6464
/*implicit*/ Key(KeyInt key) : repr_(key) {}
6565
/*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {}
6666

@@ -131,7 +131,7 @@ struct ContainerHandle {
131131
using leaf_type = T;
132132
std::unique_ptr<container_type> handle;
133133

134-
ContainerHandle() {}
134+
ContainerHandle() = default;
135135

136136
template <typename... Args>
137137
ContainerHandle(Args... args)
@@ -427,6 +427,22 @@ struct arr {
427427
return data_[idx];
428428
}
429429

430+
T& at(size_t idx) {
431+
if (idx >= size()) {
432+
throw std::out_of_range(
433+
"bounds check failed in pytree arr at index " + std::to_string(idx));
434+
}
435+
return data_[idx];
436+
}
437+
438+
const T& at(size_t idx) const {
439+
if (idx >= size()) {
440+
throw std::out_of_range(
441+
"bounds check failed in pytree arr at index " + std::to_string(idx));
442+
}
443+
return data_[idx];
444+
}
445+
430446
inline T* data() {
431447
return data_.get();
432448
}
@@ -458,7 +474,7 @@ struct arr {
458474

459475
inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
460476
size_t num = 0;
461-
while (isdigit(spec[read_idx])) {
477+
while (isdigit(spec.at(read_idx))) {
462478
num = 10 * num + (spec[read_idx] - '0');
463479
read_idx++;
464480
}
@@ -470,19 +486,22 @@ inline arr<size_t> read_node_layout(const StrTreeSpec& spec, size_t& read_idx) {
470486
arr<size_t> ret(child_num);
471487

472488
size_t child_idx = 0;
473-
while (spec[read_idx] == Config::kChildrenDataSep) {
489+
while (spec.at(read_idx) == Config::kChildrenDataSep) {
474490
++read_idx;
475-
ret[child_idx++] = read_number(spec, read_idx);
491+
ret.at(child_idx++) = read_number(spec, read_idx);
476492
}
477493
return ret;
478494
}
479495

496+
// spec_data comes from pre_parse, which guarantees 1)
497+
// spec_data.size() == spec.size() and 2) contents of spec_data are
498+
// in-bounds indices for spec, so we omit bounds checks for spec_data.
480499
template <typename Aux>
481500
TreeSpec<Aux> from_str_internal(
482501
const StrTreeSpec& spec,
483502
size_t read_idx,
484503
const arr<size_t>& spec_data) {
485-
const auto kind_char = spec[read_idx];
504+
const auto kind_char = spec.at(read_idx);
486505
switch (kind_char) {
487506
case Config::kTuple:
488507
case Config::kNamedTuple:
@@ -496,7 +515,7 @@ TreeSpec<Aux> from_str_internal(
496515
} else if (Config::kCustom == kind_char) {
497516
kind = Kind::Custom;
498517
read_idx++;
499-
assert(spec[read_idx] == '(');
518+
assert(spec.at(read_idx) == '(');
500519
auto type_str_end = spec_data[read_idx];
501520
read_idx++;
502521
custom_type = spec.substr(read_idx, type_str_end - read_idx);
@@ -515,10 +534,15 @@ TreeSpec<Aux> from_str_internal(
515534
size_t leaves_offset = 0;
516535

517536
if (size > 0) {
518-
while (spec[read_idx] != Config::kNodeDataEnd) {
537+
while (spec.at(read_idx) != Config::kNodeDataEnd) {
519538
// NOLINTNEXTLINE
520539
auto next_delim_idx = spec_data[read_idx];
521540
read_idx++;
541+
if (child_idx >= size) {
542+
throw std::out_of_range(
543+
"bounds check failed writing to pytree item at index " +
544+
std::to_string(child_idx));
545+
}
522546
c->items[child_idx] =
523547
from_str_internal<Aux>(spec, read_idx, spec_data);
524548
read_idx = next_delim_idx;
@@ -541,11 +565,16 @@ TreeSpec<Aux> from_str_internal(
541565
size_t leaves_offset = 0;
542566

543567
if (size > 0) {
544-
while (spec[read_idx] != Config::kNodeDataEnd) {
568+
while (spec.at(read_idx) != Config::kNodeDataEnd) {
545569
// NOLINTNEXTLINE
546570
auto next_delim_idx = spec_data[read_idx];
547571
read_idx++;
548-
if (spec[read_idx] == Config::kDictStrKeyQuote) {
572+
if (child_idx >= size) {
573+
throw std::out_of_range(
574+
"bounds check failed decoding pytree dict at index " +
575+
std::to_string(child_idx));
576+
}
577+
if (spec.at(read_idx) == Config::kDictStrKeyQuote) {
549578
auto key_delim_idx = spec_data[read_idx];
550579
read_idx++;
551580
const size_t key_len = key_delim_idx - read_idx;
@@ -562,7 +591,7 @@ TreeSpec<Aux> from_str_internal(
562591
c->items[child_idx] =
563592
from_str_internal<Aux>(spec, read_idx, spec_data);
564593
read_idx = next_delim_idx;
565-
leaves_offset += layout[child_idx++];
594+
leaves_offset += layout.at(child_idx++);
566595
}
567596
} else {
568597
read_idx++;
@@ -605,7 +634,9 @@ struct stack final {
605634
}
606635
};
607636

637+
// We guarantee indices in the result are in bounds.
608638
inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
639+
// Invariant: indices in stack are in bounds.
609640
stack<std::pair<size_t, size_t>> stack;
610641
size_t i = 0;
611642
const size_t size = spec.size();
@@ -627,11 +658,16 @@ inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
627658
case Config::kDictStrKeyQuote: {
628659
size_t idx = i;
629660
i++;
630-
while (spec[i] != Config::kDictStrKeyQuote) {
661+
while (spec.at(i) != Config::kDictStrKeyQuote) {
631662
i++;
632663
}
633-
ret[idx] = i;
634-
ret[i] = idx;
664+
if (i >= size) {
665+
throw std::out_of_range(
666+
"bounds check failed while parsing dictionary key at index " +
667+
std::to_string(i));
668+
}
669+
ret.at(idx) = i;
670+
ret.at(i) = idx;
635671
break;
636672
}
637673
case Config::kChildrenSep: {

extension/pytree/test/test_pytree.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using Leaf = int32_t;
2222
TEST(PyTreeTest, ArrBasic) {
2323
arr<int> x(5);
2424
ASSERT_EQ(x.size(), 5);
25+
EXPECT_THROW(x.at(5), std::out_of_range);
2526
for (int ii = 0; ii < x.size(); ++ii) {
2627
x[ii] = 2 * ii;
2728
}
@@ -197,3 +198,22 @@ TEST(pytree, FlattenNestedDict) {
197198
ASSERT_EQ(*leaves[i], items[i]);
198199
}
199200
}
201+
202+
TEST(pytree, EmptySpec) {
203+
Leaf items[1] = {9};
204+
EXPECT_THROW(unflatten("", items), std::out_of_range);
205+
}
206+
207+
TEST(pytree, BoundsCheckListLayout) {
208+
// Malformed: the layout declares one child, but the spec contains two.
209+
std::string spec = "L1#1($,$)";
210+
Leaf items[2] = {11, 12};
211+
EXPECT_THROW(unflatten(spec, items), std::out_of_range);
212+
}
213+
214+
TEST(pytree, BoundsCheckDictLayout) {
215+
// Malformed: the layout declares one child, but the spec contains two.
216+
std::string spec = "D1#1('key0':$,'key1':$)";
217+
Leaf items[2] = {11, 12};
218+
EXPECT_THROW(unflatten(spec, items), std::out_of_range);
219+
}

0 commit comments

Comments
 (0)