Skip to content

Commit 2b90570

Browse files
authored
add BroadcastIndexesRange (#8864)
See class comment. In brief, this adds an iterable range to make broadcasting ops convenient and efficient to implement.
1 parent e37129d commit 2b90570

File tree

6 files changed

+431
-2
lines changed

6 files changed

+431
-2
lines changed
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <type_traits>

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_dimension_limit.h>
namespace torch::executor {
21+
22+
namespace internal {
23+
template <std::size_t kNumInputs>
24+
class BroadcastIndexesIterator {
25+
public:
26+
using difference_type = ssize_t;
27+
using value_type = std::array<ssize_t, kNumInputs + 1>;
28+
using reference = const value_type&;
29+
using pointer = const value_type*;
30+
using iterator_category = std::forward_iterator_tag;
31+
32+
BroadcastIndexesIterator() = default;
33+
34+
template <typename... Args>
35+
explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
36+
: output_dim_(output.dim()),
37+
output_shape_(output.sizes()),
38+
effective_input_broadcast_strides_{
39+
effective_input_broadcast_stride(output, args)...} {
40+
static_assert(
41+
sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
42+
"BroadcastIndexesIterator constructor requires kNumInputs input tensor"
43+
"arguments!");
44+
}
45+
46+
struct make_end_t {
47+
explicit constexpr make_end_t() = default;
48+
};
49+
50+
template <typename... Args>
51+
BroadcastIndexesIterator(make_end_t, const Tensor& t, const Args&... args)
52+
: current_indexes_{
53+
t.numel(),
54+
0,
55+
} {}
56+
57+
bool operator==(const BroadcastIndexesIterator& rhs) const {
58+
return output_index() == rhs.output_index();
59+
}
60+
61+
bool operator!=(const BroadcastIndexesIterator& rhs) const {
62+
return !operator==(rhs);
63+
}
64+
65+
reference operator*() const {
66+
return current_indexes_;
67+
}
68+
69+
pointer operator->() const {
70+
return &current_indexes_;
71+
}
72+
73+
BroadcastIndexesIterator& operator++() {
74+
output_index()++;
75+
// TODO: add optimization for particular input tensors not being
76+
// broadcasted?
77+
for (auto ii = output_dim_ - 1; ii >= 0; --ii) {
78+
// You might wonder what happens if output_shape_[ii] == 0. In
79+
// that case, output.numel() would be 0, and thus we would have
80+
// begin() == end() and no iteration.
81+
if ET_UNLIKELY (delinearized_output_index_[ii] == output_shape_[ii] - 1) {
82+
const auto old_delinearized_output_index_item =
83+
delinearized_output_index_[ii];
84+
delinearized_output_index_[ii] = 0;
85+
for (const auto jj : c10::irange(1, kNumInputs + 1)) {
86+
current_indexes_[jj] -= old_delinearized_output_index_item *
87+
effective_input_broadcast_strides_[jj - 1][ii];
88+
}
89+
} else {
90+
delinearized_output_index_[ii]++;
91+
for (const auto jj : c10::irange(1, kNumInputs + 1)) {
92+
current_indexes_.at(jj) +=
93+
effective_input_broadcast_strides_[jj - 1][ii];
94+
}
95+
break;
96+
}
97+
}
98+
return *this;
99+
}
100+
101+
BroadcastIndexesIterator operator++(int) {
102+
auto it = *this;
103+
operator++();
104+
return it;
105+
}
106+
107+
difference_type operator-(const BroadcastIndexesIterator& rhs) const {
108+
return difference_type(output_index() - rhs.output_index());
109+
}
110+
111+
private:
112+
ssize_t output_index() const {
113+
return current_indexes_[0];
114+
}
115+
116+
ssize_t& output_index() {
117+
return current_indexes_[0];
118+
}
119+
120+
std::array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>
121+
effective_input_broadcast_stride(const Tensor& output, const Tensor& t)
122+
const {
123+
std::array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>
124+
result = {0};
125+
ET_CHECK_MSG(
126+
t.dim() <= output.dim(),
127+
"input to broadcasting op should have dim at most output dim, but %d > %d!",
128+
(int)t.dim(),
129+
(int)output.dim());
130+
131+
const auto num_leading_ones = output.dim() - t.dim();
132+
for (const auto idx : c10::irange(num_leading_ones)) {
133+
result[idx] = 0;
134+
}
135+
const auto t_sizes = t.sizes();
136+
const auto t_strides = t.strides();
137+
for (const auto idx :
138+
c10::irange(num_leading_ones, num_leading_ones + t.dim())) {
139+
result[idx] = t_sizes[idx - num_leading_ones] == 1
140+
? 0
141+
: t_strides[idx - num_leading_ones];
142+
}
143+
return result;
144+
}
145+
146+
// The 0th entry is the current linear index into the output,
147+
// followed by kNumInputs input indexes.
148+
std::array<ssize_t, kNumInputs + 1> current_indexes_ = {0};
149+
using ShapeType = std::
150+
array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>;
151+
ShapeType delinearized_output_index_ = {0};
152+
ssize_t output_dim_;
153+
ArrayRef<exec_aten::SizesType> output_shape_;
154+
// The linear index for a broadcast tensor is
155+
// sum(delinearized_output_index_[i] * input_stride_[i] if
156+
// padded_input_shape_[i] != 1 else 0), where padded_input_shape is
157+
// input.sizes() with leading 1s added to make its size equal to
158+
// output_dim. This is straightforwardly implementable with an
159+
// adjusted stride array that contains 0s where the padded input
160+
// shape would contain 1s.
161+
std::array<ShapeType, kNumInputs> effective_input_broadcast_strides_ = {
162+
{{0}}};
163+
};
164+
} // namespace internal
165+
166+
/**
 * Efficient mechanism for looping over the index space for an output
 * tensor and kNumInputs possibly-broadcasted input tensors. Use as follows:
 *
 * auto* output_data = output.mutable_data_ptr<OutputType>();
 * const auto* a_data = a.const_data_ptr<AType>();
 * const auto* b_data = b.const_data_ptr<BType>();
 * for (const auto [output_index, a_index, b_index] :
 *      BroadcastIndexesRange<2>(output, a, b)) {
 *   // Access output_data[output_index], a_data[a_index], and b_data[b_index].
 * }
 *
 * (where OutputType, AType, and BType are known concrete types.)
 *
 * Unlike looping using delinearize_index() and
 * linearize_access_indexes(), BroadcastIndexesRange avoids expensive
 * division and modulo operations on each iteration.
 */
template <std::size_t kNumInputs>
class BroadcastIndexesRange {
 public:
  using iterator = internal::BroadcastIndexesIterator<kNumInputs>;

  // Stores pointers only: the output and input tensors must outlive this
  // range and any iterators obtained from it.
  template <typename... Args>
  BroadcastIndexesRange(const Tensor& output, const Args&... args)
      : tensors_{&output, (&args)...} {}

  // Unpacks the stored tensor pointers into the iterator's
  // (output, inputs...) constructor.
  iterator begin() const {
    return std::apply(
        [](const auto&... args) { return iterator((*args)...); }, tensors_);
  }

  iterator end() const {
    return std::apply(
        [](const auto&... args) {
          return iterator(typename iterator::make_end_t(), (*args)...);
        },
        tensors_);
  }

 private:
  // Output tensor first, followed by the kNumInputs input tensors.
  std::array<const Tensor*, kNumInputs + 1> tensors_;
};
} // namespace torch::executor

kernels/portable/cpu/util/targets.bzl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,19 @@ def define_common_targets():
278278
visibility = ["//executorch/kernels/portable/cpu/..."],
279279
)
280280

281+
runtime.cxx_library(
282+
name = "broadcast_indexes_range",
283+
exported_headers = ["broadcast_indexes_range.h"],
284+
deps = [
285+
"//executorch/runtime/core/exec_aten:lib",
286+
"//executorch/runtime/core/exec_aten/util:tensor_dimension_limit",
287+
],
288+
visibility = [
289+
"//executorch/...",
290+
"@EXECUTORCH_CLIENTS",
291+
],
292+
)
293+
281294
# Utility functions that can be used by operators that perform reduction
282295
for aten_mode in get_aten_mode_options():
283296
suffix = "_aten" if aten_mode else ""

kernels/portable/cpu/util/test/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..)
1919

2020
include(${EXECUTORCH_ROOT}/build/Test.cmake)
2121

22-
set(_test_srcs broadcast_test.cpp reduce_test.cpp)
22+
set(_test_srcs broadcast_indexes_range_test.cpp broadcast_test.cpp
23+
reduce_test.cpp
24+
)
2325

2426
et_cxx_test(
2527
kernels_portable_cpu_util_test SOURCES ${_test_srcs} EXTRA_LIBS

0 commit comments

Comments
 (0)