Skip to content

Commit da70c5d

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
linearize_access_indexes only needs dim of broadcast_to (#270)
Summary: Pull Request resolved: #270 The utility function `linearize_access_indexes` currently takes the `broadcast_to` tensor as the second argument. However, only the dimension of the tensor is needed. Updating the function, to take only the dimension, so that it can be used in scenarios where a broadcast_to tensor is not available, and only a broadcast shape/ndim is known. Reviewed By: kirklandsign Differential Revision: D49159987 fbshipit-source-id: c6532b2a48b823234accdc3e69edc0e21147658a
1 parent e9c5851 commit da70c5d

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

kernels/portable/cpu/util/broadcast_util.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,11 @@ void delinearize_index(
263263

264264
size_t linearize_access_indexes(
265265
ArrayRef<size_t> indexes_broadcast_to,
266-
const Tensor& broadcast_to,
266+
ssize_t broadcast_to_ndim,
267267
const Tensor& broadcast_from) {
268-
size_t num_skip_dims = broadcast_to.dim() - broadcast_from.dim();
268+
size_t num_skip_dims = broadcast_to_ndim - broadcast_from.dim();
269269
ArrayRef<size_t> indexes_broadcast_from = indexes_broadcast_to.slice(
270-
num_skip_dims, broadcast_to.dim() - num_skip_dims);
270+
num_skip_dims, broadcast_to_ndim - num_skip_dims);
271271

272272
ET_CHECK(indexes_broadcast_from.size() == broadcast_from.dim());
273273

kernels/portable/cpu/util/broadcast_util.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,13 @@ void delinearize_index(
200200
* broadcast_to tensor and itself.
201201
*
202202
* @param[in] indexes The tensor access indexes of broadcast_to tensor
203-
* @param[in] broadcast_to The tensor with the broadcasted shape.
203+
* @param[in] broadcast_to_ndim The number of dims of the broadcasted shape.
204204
* @param[in] broadcast_from The tensor to be broadcasted.
205205
* @returns The flattend index for broadcast_from tensor.
206206
*/
207207
size_t linearize_access_indexes(
208208
ArrayRef<size_t> indexes_broadcast_to,
209-
const Tensor& broadcast_to,
209+
ssize_t broadcast_to_ndim,
210210
const Tensor& broadcast_from);
211211

212212
//
@@ -241,10 +241,10 @@ inline void apply_binary_elementwise_fn(
241241
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
242242

243243
if (a_is_broadcasted) {
244-
a_linear_index = linearize_access_indexes(out_indexes, out, a);
244+
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
245245
}
246246
if (b_is_broadcasted) {
247-
b_linear_index = linearize_access_indexes(out_indexes, out, b);
247+
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
248248
}
249249
}
250250

@@ -290,13 +290,13 @@ inline void apply_ternary_elementwise_fn(
290290
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
291291

292292
if (a_is_broadcasted) {
293-
a_linear_index = linearize_access_indexes(out_indexes, out, a);
293+
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
294294
}
295295
if (b_is_broadcasted) {
296-
b_linear_index = linearize_access_indexes(out_indexes, out, b);
296+
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
297297
}
298298
if (c_is_broadcasted) {
299-
c_linear_index = linearize_access_indexes(out_indexes, out, c);
299+
c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
300300
}
301301
}
302302

kernels/portable/cpu/util/test/broadcast_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ TEST(BroadcastUtilTest, LinearizeIndex) {
169169
size_t test_indexes[] = {0, 0, 0, i};
170170
ArrayRef<size_t> broadcast_to_indexes(test_indexes);
171171
size_t linear_index = linearize_access_indexes(
172-
broadcast_to_indexes, broadcast_to, broadcast_from);
172+
broadcast_to_indexes, broadcast_to.dim(), broadcast_from);
173173
EXPECT_EQ(linear_index, 0);
174174
}
175175

@@ -179,7 +179,7 @@ TEST(BroadcastUtilTest, LinearizeIndex) {
179179
size_t test_indexes[] = {0, i, 2, 3};
180180
ArrayRef<size_t> broadcast_to_indexes(test_indexes);
181181
size_t linear_index = linearize_access_indexes(
182-
broadcast_to_indexes, broadcast_to, broadcast_from);
182+
broadcast_to_indexes, broadcast_to.dim(), broadcast_from);
183183
EXPECT_EQ(linear_index, 2);
184184
}
185185
}

0 commit comments

Comments
 (0)