Skip to content

Commit 6890f61

Browse files
committed
add a bunch of bounds checking to pytree
Pull Request resolved: #7654 It's possible to pass arbitrary string input to pytree from Python; let's not have a bunch of low-hanging memory safety issues. ghstack-source-id: 261987905 @exported-using-ghexport Differential Revision: [D68166303](https://our.internmc.facebook.com/intern/diff/D68166303/)
1 parent e7530c4 commit 6890f61

File tree

2 files changed

+63
-14
lines changed

2 files changed

+63
-14
lines changed

extension/pytree/pytree.h

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct Key {
6060
std::variant<std::monostate, KeyInt, KeyStr> repr_;
6161

6262
public:
63-
Key() {}
63+
Key() = default;
6464
/*implicit*/ Key(KeyInt key) : repr_(key) {}
6565
/*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {}
6666

@@ -131,7 +131,7 @@ struct ContainerHandle {
131131
using leaf_type = T;
132132
std::unique_ptr<container_type> handle;
133133

134-
ContainerHandle() {}
134+
ContainerHandle() = default;
135135

136136
template <typename... Args>
137137
ContainerHandle(Args... args)
@@ -427,6 +427,20 @@ struct arr {
427427
return data_[idx];
428428
}
429429

430+
T& at(size_t idx) {
431+
if (idx >= size()) {
432+
throw std::out_of_range("bounds check failed in pytree arr");
433+
}
434+
return data_[idx];
435+
}
436+
437+
const T& at(size_t idx) const {
438+
if (idx >= size()) {
439+
throw std::out_of_range("bounds check failed in pytree arr");
440+
}
441+
return data_[idx];
442+
}
443+
430444
inline T* data() {
431445
return data_.get();
432446
}
@@ -458,7 +472,7 @@ struct arr {
458472

459473
inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
460474
size_t num = 0;
461-
while (isdigit(spec[read_idx])) {
475+
while (isdigit(spec.at(read_idx))) {
462476
num = 10 * num + (spec[read_idx] - '0');
463477
read_idx++;
464478
}
@@ -470,19 +484,22 @@ inline arr<size_t> read_node_layout(const StrTreeSpec& spec, size_t& read_idx) {
470484
arr<size_t> ret(child_num);
471485

472486
size_t child_idx = 0;
473-
while (spec[read_idx] == Config::kChildrenDataSep) {
487+
while (spec.at(read_idx) == Config::kChildrenDataSep) {
474488
++read_idx;
475-
ret[child_idx++] = read_number(spec, read_idx);
489+
ret.at(child_idx++) = read_number(spec, read_idx);
476490
}
477491
return ret;
478492
}
479493

494+
// spec_data comes from pre_parse, which guarantees 1)
495+
// spec_data.size() == spec.size() and 2) contents of spec_data are
496+
// in-bounds indices for spec, so we omit bounds checks for spec_data.
480497
template <typename Aux>
481498
TreeSpec<Aux> from_str_internal(
482499
const StrTreeSpec& spec,
483500
size_t read_idx,
484501
const arr<size_t>& spec_data) {
485-
const auto kind_char = spec[read_idx];
502+
const auto kind_char = spec.at(read_idx);
486503
switch (kind_char) {
487504
case Config::kTuple:
488505
case Config::kNamedTuple:
@@ -496,7 +513,7 @@ TreeSpec<Aux> from_str_internal(
496513
} else if (Config::kCustom == kind_char) {
497514
kind = Kind::Custom;
498515
read_idx++;
499-
assert(spec[read_idx] == '(');
516+
assert(spec.at(read_idx) == '(');
500517
auto type_str_end = spec_data[read_idx];
501518
read_idx++;
502519
custom_type = spec.substr(read_idx, type_str_end - read_idx);
@@ -515,10 +532,14 @@ TreeSpec<Aux> from_str_internal(
515532
size_t leaves_offset = 0;
516533

517534
if (size > 0) {
518-
while (spec[read_idx] != Config::kNodeDataEnd) {
535+
while (spec.at(read_idx) != Config::kNodeDataEnd) {
519536
// NOLINTNEXTLINE
520537
auto next_delim_idx = spec_data[read_idx];
521538
read_idx++;
539+
if (child_idx >= size) {
540+
throw std::out_of_range(
541+
"bounds check failed writing to pytree item");
542+
}
522543
c->items[child_idx] =
523544
from_str_internal<Aux>(spec, read_idx, spec_data);
524545
read_idx = next_delim_idx;
@@ -541,11 +562,14 @@ TreeSpec<Aux> from_str_internal(
541562
size_t leaves_offset = 0;
542563

543564
if (size > 0) {
544-
while (spec[read_idx] != Config::kNodeDataEnd) {
565+
while (spec.at(read_idx) != Config::kNodeDataEnd) {
545566
// NOLINTNEXTLINE
546567
auto next_delim_idx = spec_data[read_idx];
547568
read_idx++;
548-
if (spec[read_idx] == Config::kDictStrKeyQuote) {
569+
if (child_idx >= size) {
570+
throw std::out_of_range("bounds check failed decoding pytree dict");
571+
}
572+
if (spec.at(read_idx) == Config::kDictStrKeyQuote) {
549573
auto key_delim_idx = spec_data[read_idx];
550574
read_idx++;
551575
const size_t key_len = key_delim_idx - read_idx;
@@ -562,7 +586,7 @@ TreeSpec<Aux> from_str_internal(
562586
c->items[child_idx] =
563587
from_str_internal<Aux>(spec, read_idx, spec_data);
564588
read_idx = next_delim_idx;
565-
leaves_offset += layout[child_idx++];
589+
leaves_offset += layout.at(child_idx++);
566590
}
567591
} else {
568592
read_idx++;
@@ -605,7 +629,9 @@ struct stack final {
605629
}
606630
};
607631

632+
// We guarantee indicies in the result are in bounds.
608633
inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
634+
// Invariant: indices in stack are in bounds.
609635
stack<std::pair<size_t, size_t>> stack;
610636
size_t i = 0;
611637
const size_t size = spec.size();
@@ -627,11 +653,15 @@ inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
627653
case Config::kDictStrKeyQuote: {
628654
size_t idx = i;
629655
i++;
630-
while (spec[i] != Config::kDictStrKeyQuote) {
656+
while (spec.at(i) != Config::kDictStrKeyQuote) {
631657
i++;
632658
}
633-
ret[idx] = i;
634-
ret[i] = idx;
659+
if (i >= size) {
660+
throw std::out_of_range(
661+
"bounds check failed while parsing dictionary key");
662+
}
663+
ret.at(idx) = i;
664+
ret.at(i) = idx;
635665
break;
636666
}
637667
case Config::kChildrenSep: {

extension/pytree/test/test_pytree.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,22 @@ TEST(pytree, FlattenNestedDict) {
183183
ASSERT_EQ(*leaves[i], items[i]);
184184
}
185185
}
186+
187+
TEST(pytree, EmptySpec) {
188+
Leaf items[1] = {9};
189+
EXPECT_THROW(unflatten("", items), std::out_of_range);
190+
}
191+
192+
TEST(pytree, BoundsCheckListLayout) {
193+
// Malformed: layout one child, have two
194+
std::string spec = "L1#1($,$)";
195+
Leaf items[2] = {11, 12};
196+
EXPECT_THROW(unflatten(spec, items), std::out_of_range);
197+
}
198+
199+
TEST(pytree, BoundsCheckDictLayout) {
200+
// Malformed: layout one child, have two.
201+
std::string spec = "D1#1('key0':$,'key1':$)";
202+
Leaf items[2] = {11, 12};
203+
EXPECT_THROW(unflatten(spec, items), std::out_of_range);
204+
}

0 commit comments

Comments
 (0)