From bb084a66a34af6a3dd68ea3f713242513b40615e Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 18 Mar 2025 13:48:10 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/util/broadcast_indexes_range.h | 5 +++++ .../portable/cpu/util/test/broadcast_indexes_range_test.cpp | 6 ++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h index 7b78f4c2814..aaf7207d0c9 100644 --- a/kernels/portable/cpu/util/broadcast_indexes_range.h +++ b/kernels/portable/cpu/util/broadcast_indexes_range.h @@ -122,6 +122,11 @@ class BroadcastIndexesIterator { } output_index() += n; + if (output_dim_or_zero_if_no_broadcasting_ == 0) { + std::fill( + current_indexes_.begin() + 1, current_indexes_.end(), output_index()); + return *this; + } delinearize_index( output_index(), output_shape_, diff --git a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp index 519cd9fe9f9..1023915ea66 100644 --- a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp +++ b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp @@ -44,7 +44,9 @@ TEST(BroadcastIndexesRangeTest, OneDNotBroadcasted) { Tensor out = tf.zeros({5}); int idx = 0; - for (const auto& elem : range_to_vec(BroadcastIndexesRange<1>(out, out))) { + const auto range = BroadcastIndexesRange<1>(out, out); + for (const auto& elem : range_to_vec(range)) { + EXPECT_EQ(*(range.begin() + idx), elem); EXPECT_EQ(elem[0], idx++); EXPECT_EQ(elem[0], elem[1]); } @@ -71,7 +73,7 @@ TEST(BroadcastIndexesRangeTest, ScalarBroadcastToOneD) { template void test_operator_plus(const Range& range) { size_t idx = 0; - for (const auto indexes : range) { + for (const auto& indexes : range) { EXPECT_EQ(*(range.begin() + idx), indexes); idx++; }