Skip to content

Commit 5a7345b

Browse files
authored
add begin/end to ExecuTorch pytree::arr
Differential Revision: D68166302 Pull Request resolved: #7653
1 parent 3337fe5 commit 5a7345b

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

extension/pytree/aten_util/ivalue_util.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ std::pair<std::vector<at::Tensor>, std::unique_ptr<TreeSpec<Empty>>> flatten(
131131
auto p = flatten(c);
132132

133133
std::vector<at::Tensor> tensors;
134-
for (int i = 0; i < p.first.size(); ++i) {
135-
tensors.emplace_back(p.first[i]->toTensor());
134+
for (const auto& item : p.first) {
135+
tensors.emplace_back(item->toTensor());
136136
}
137137

138138
return {tensors, std::move(p.second)};

extension/pytree/pytree.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,22 @@ struct arr {
431431
return data_.get();
432432
}
433433

434+
T* begin() {
435+
return data_.get();
436+
}
437+
438+
T* end() {
439+
return begin() + size();
440+
}
441+
442+
const T* begin() const {
443+
return data_.get();
444+
}
445+
446+
const T* end() const {
447+
return begin() + size();
448+
}
449+
434450
inline size_t size() const {
435451
return n_;
436452
}

extension/pytree/test/test_pytree.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,27 @@
1111
#include <gtest/gtest.h>
1212
#include <string>
1313

14+
using ::executorch::extension::pytree::arr;
1415
using ::executorch::extension::pytree::ContainerHandle;
1516
using ::executorch::extension::pytree::Key;
1617
using ::executorch::extension::pytree::Kind;
1718
using ::executorch::extension::pytree::unflatten;
1819

1920
using Leaf = int32_t;
2021

22+
TEST(PyTreeTest, ArrBasic) {
23+
arr<int> x(5);
24+
ASSERT_EQ(x.size(), 5);
25+
for (int ii = 0; ii < x.size(); ++ii) {
26+
x[ii] = 2 * ii;
27+
}
28+
int idx = 0;
29+
for (const auto item : x) {
30+
EXPECT_EQ(item, 2 * idx);
31+
++idx;
32+
}
33+
}
34+
2135
TEST(PyTreeTest, List) {
2236
Leaf items[2] = {11, 12};
2337
std::string spec = "L2#1#1($,$)";

0 commit comments

Comments
 (0)