diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h
index c623fdb4c31..5fa50d8d212 100644
--- a/kernels/portable/cpu/util/broadcast_indexes_range.h
+++ b/kernels/portable/cpu/util/broadcast_indexes_range.h
@@ -14,6 +14,7 @@
 #include
 #include
+#include <iterator>
 #include
 #include

@@ -78,7 +79,9 @@ class BroadcastIndexesIterator {
       // You might wonder what happens if output_shape_[ii] == 0. In
       // that case, output.numel() would be 0, and thus we would have
       // begin() == end() and no iteration.
-      if ET_UNLIKELY (delinearized_output_index_[ii] == output_shape_[ii] - 1) {
+      if ET_UNLIKELY (
+          static_cast<exec_aten::SizesType>(delinearized_output_index_[ii]) ==
+          output_shape_[ii] - 1) {
        const auto old_delinearized_output_index_item =
            delinearized_output_index_[ii];
        delinearized_output_index_[ii] = 0;
@@ -104,11 +107,42 @@
     return it;
   }

+  BroadcastIndexesIterator& operator+=(difference_type n) {
+    if (n <= 3) {
+      std::advance(*this, n);
+      return *this;
+    }
+
+    output_index() += n;
+    delinearize_index(
+        output_index(),
+        output_shape_,
+        delinearized_output_index_.data(),
+        delinearized_output_index_.size());
+    for (const auto ii : c10::irange(1, kNumInputs + 1)) {
+      current_indexes_[ii] = 0;
+      for (const auto jj : c10::irange(output_dim_)) {
+        current_indexes_[ii] += delinearized_output_index_[jj] *
+            effective_input_broadcast_strides_[ii - 1][jj];
+      }
+    }
+    return *this;
+  }
+
+  BroadcastIndexesIterator operator+(difference_type n) {
+    auto it = *this;
+    it += n;
+    return it;
+  }
+
   difference_type operator-(const BroadcastIndexesIterator& rhs) const {
     return difference_type(output_index() - rhs.output_index());
   }

  private:
+  using ShapeType =
+      std::array<std::size_t, executorch::runtime::kTensorDimensionLimit>;
+
   ssize_t output_index() const {
     return current_indexes_[0];
   }
@@ -117,11 +151,10 @@ class BroadcastIndexesIterator {
     return current_indexes_[0];
   }

-  std::array<std::size_t, executorch::runtime::kTensorDimensionLimit>
-      effective_input_broadcast_stride(const Tensor& output, const Tensor& t)
-      const {
-    std::array<std::size_t, executorch::runtime::kTensorDimensionLimit>
-        result = {0};
+  ShapeType effective_input_broadcast_stride(
+      const Tensor& output,
+      const Tensor& t) const {
+    ShapeType result = {0};
     ET_CHECK_MSG(
         t.dim() <= output.dim(),
         "input to broadcasting op should have dim at most output dim, but %d > %d!",
@@ -146,8 +179,6 @@
   // The 0th entry is the current linear index into the output,
   // followed by kNumInputs input indexes.
   std::array<ssize_t, kNumInputs + 1> current_indexes_ = {0};
-  using ShapeType = std::
-      array<std::size_t, executorch::runtime::kTensorDimensionLimit>;
   ShapeType delinearized_output_index_ = {0};
   ssize_t output_dim_;
   ArrayRef<exec_aten::SizesType> output_shape_;
diff --git a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
index f147958558d..519cd9fe9f9 100644
--- a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
+++ b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
@@ -68,6 +68,15 @@ TEST(BroadcastIndexesRangeTest, ScalarBroadcastToOneD) {
   EXPECT_EQ(expected, actual);
 }

+template <typename Range>
+void test_operator_plus(const Range& range) {
+  size_t idx = 0;
+  for (const auto indexes : range) {
+    EXPECT_EQ(*(range.begin() + idx), indexes);
+    idx++;
+  }
+}
+
 // [1] -> [H, W]
 // [W] -> [H, W]
 // [1, 1] -> [H, W]
@@ -87,14 +96,15 @@
   Tensor in_not_broadcast = tf.zeros({3, 4});

-  auto actual = range_to_vec(BroadcastIndexesRange<6>(
+  const auto range = BroadcastIndexesRange<6>(
       out,
       in_0d_scalar,
       in_1d_scalar,
       in_2d_scalar,
       in_row,
       in_col,
-      in_not_broadcast));
+      in_not_broadcast);
+  auto actual = range_to_vec(range);
   decltype(actual) expected = {
       {0, 0, 0, 0, 0, 0, 0},
       {1, 0, 0, 0, 1, 0, 1},
@@ -110,6 +120,8 @@
       {11, 0, 0, 0, 3, 2, 11},
   };
   EXPECT_EQ(expected, actual);
+
+  test_operator_plus(range);
 }

 // Make sure nothing is thrown off by a size-1 dim in the output:
@@ -138,20 +150,20 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDWith1InOutputShapeExhaustive) {
   Tensor in_col = tf.zeros({H, 1});

   size_t idx = 0;
+  const auto range_row = BroadcastIndexesRange<5>(
+      out_row,
+      in_0d_scalar,
+      in_1d_scalar,
+      in_2d_scalar,
+      in_row,
+      in_leading_one_row);
   for (const auto
           [out_idx,
            in_0d_idx,
            in_1d_idx,
            in_2d_idx,
            in_row_idx,
-           in_leading_one_row_idx] :
-       BroadcastIndexesRange<5>(
-           out_row,
-           in_0d_scalar,
-           in_1d_scalar,
-           in_2d_scalar,
-           in_row,
-           in_leading_one_row)) {
+           in_leading_one_row_idx] : range_row) {
     EXPECT_EQ(out_idx, idx++);
     EXPECT_EQ(in_0d_idx, 0);
     EXPECT_EQ(in_1d_idx, 0);
@@ -160,16 +172,21 @@
     EXPECT_EQ(in_leading_one_row_idx, out_idx);
   }

+  test_operator_plus(range_row);
+
   idx = 0;
+  const auto range_col = BroadcastIndexesRange<4>(
+      out_col, in_0d_scalar, in_1d_scalar, in_2d_scalar, in_col);
   for (const auto [out_idx, in_0d_idx, in_1d_idx, in_2d_idx, in_col_idx] :
-       BroadcastIndexesRange<4>(
-           out_col, in_0d_scalar, in_1d_scalar, in_2d_scalar, in_col)) {
+       range_col) {
     EXPECT_EQ(out_idx, idx++);
     EXPECT_EQ(in_0d_idx, 0);
     EXPECT_EQ(in_1d_idx, 0);
     EXPECT_EQ(in_2d_idx, 0);
     EXPECT_EQ(in_col_idx, out_idx);
   }
+
+  test_operator_plus(range_col);
 }

 // [1, 1, 1] -> [C, H, W]
@@ -197,16 +214,17 @@ TEST(BroadcastIndexesRangeTest, ThreeDBroadcasting) {
   // take the opportunity to mutation test against delinearize_index
   // and linearize_access_indexes.
   int idx = 0;
-  for (const auto indexes : BroadcastIndexesRange<8>(
-           out,
-           input_tensors[0],
-           input_tensors[1],
-           input_tensors[2],
-           input_tensors[3],
-           input_tensors[4],
-           input_tensors[5],
-           input_tensors[6],
-           input_tensors[7])) {
+  const auto range = BroadcastIndexesRange<8>(
+      out,
+      input_tensors[0],
+      input_tensors[1],
+      input_tensors[2],
+      input_tensors[3],
+      input_tensors[4],
+      input_tensors[5],
+      input_tensors[6],
+      input_tensors[7]);
+  for (const auto indexes : range) {
     const auto out_idx = indexes[0];
     EXPECT_EQ(out_idx, idx++);
     size_t out_indexes[executorch::runtime::kTensorDimensionLimit];
@@ -219,6 +237,7 @@
           out_indexes, out.dim(), input_tensors[tensor_idx]));
     }
   }
+  test_operator_plus(range);
 }

 // 4-D should generalize, but we will go ahead and test:
@@ -235,8 +254,9 @@ void four_d_broadcasting_test() {
   // take the opportunity to mutation test against delinearize_index
   // and linearize_access_indexes.
   int idx = 0;
-  for (const auto [out_idx, in_cw_idx, in_nh_idx] :
-       BroadcastIndexesRange<2>(out, in_broadcast_cw, in_broadcast_nh)) {
+  const auto range =
+      BroadcastIndexesRange<2>(out, in_broadcast_cw, in_broadcast_nh);
+  for (const auto [out_idx, in_cw_idx, in_nh_idx] : range) {
     EXPECT_EQ(out_idx, idx++);
     size_t out_indexes[executorch::runtime::kTensorDimensionLimit];
     delinearize_index(
@@ -248,6 +268,8 @@
         in_nh_idx,
         linearize_access_indexes(out_indexes, out.dim(), in_broadcast_nh));
   }
+
+  test_operator_plus(range);
 }

 TEST(BroadcastIndexesRangeTest, FourDBroadcasting) {
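Review note, not part of the patch: the payoff of the new operator+= / operator+ is cheap seeking. Jumping the iterator to output element n costs one delinearize_index pass plus one stride dot-product per input (O(kNumInputs * dim)) instead of n repeated increments. Below is a minimal sketch of how a caller might exploit that to process an arbitrary slice of a broadcast range, e.g. one chunk of a parallelized elementwise kernel. The helper name apply_broadcast_chunk, the functor parameter, and the assumption that Tensor and BroadcastIndexesRange are visible unqualified (as they are in the tests above) are illustrative only; the include path follows the file path in this diff.

#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>

// Hypothetical helper: apply `fn` to output elements [begin, end) of a
// two-input broadcast. `range.begin() + begin` is the new O(dim) seek;
// before this patch the only option was `begin` successive ++ operations.
template <typename Fn>
void apply_broadcast_chunk(
    const Tensor& out, // assumed: executorch Tensor alias in scope
    const Tensor& a,
    const Tensor& b,
    ssize_t begin,
    ssize_t end,
    Fn&& fn) {
  const auto range = BroadcastIndexesRange<2>(out, a, b);
  auto it = range.begin() + begin; // single seek, not `begin` increments
  for (ssize_t i = begin; i < end; ++i, ++it) {
    // *it is {out_index, a_index, b_index}, exactly as in the tests above.
    const auto [out_idx, a_idx, b_idx] = *it;
    fn(out_idx, a_idx, b_idx);
  }
}

Note also the n <= 3 fast path in operator+=: for tiny jumps, repeated increments via std::advance are cheaper than a full delinearize-and-recompute. Since test_operator_plus sweeps idx upward from 0, it exercises both the increment path (idx <= 3) and the recompute path on every range it is given.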