Skip to content

Commit ee7d388

Browse files
authored
Use tensor_shape_to_c_string for error in check_mask_indices (#8314)
Rolling out for #7902
1 parent 4e3a8bd commit ee7d388

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

kernels/portable/cpu/util/advanced_index_util.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
10+
#include <executorch/runtime/core/exec_aten/util/tensor_shape_to_c_string.h>
1011
#include <executorch/runtime/kernel/kernel_includes.h>
1112

1213
namespace torch {
@@ -49,9 +50,22 @@ bool check_mask_indices(const Tensor& in, TensorOptList indices) {
4950
ET_LOG_MSG_AND_RETURN_IF_FALSE(
5051
index.dim() > 0, "Zero-dimensional mask index not allowed");
5152
for (auto j = 0; j < index.dim(); j++) {
52-
ET_LOG_MSG_AND_RETURN_IF_FALSE(
53-
index.size(j) == in.size(in_i + j),
54-
"The shape of mask index must match the sizes of the corresponding input dimensions.");
53+
if (index.size(j) != in.size(in_i + j)) {
54+
#ifdef ET_LOG_ENABLED
55+
auto mask_shape = executorch::runtime::tensor_shape_to_c_string(
56+
executorch::runtime::Span<const Tensor::SizesType>(
57+
index.sizes().data(), index.sizes().size()));
58+
auto input_shape = executorch::runtime::tensor_shape_to_c_string(
59+
executorch::runtime::Span<const Tensor::SizesType>(
60+
in.sizes().data() + in_i, index.sizes().size()));
61+
ET_LOG(
62+
Error,
63+
"The shape of mask index %s must match the sizes of the corresponding input dimensions %s.",
64+
mask_shape.data(),
65+
input_shape.data());
66+
#endif // ET_LOG_ENABLED
67+
return false;
68+
}
5569
}
5670
in_i += index.dim();
5771
} else {

kernels/portable/cpu/util/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def define_common_targets():
117117
compiler_flags = ["-Wno-missing-prototypes"],
118118
deps = [
119119
":broadcast_util",
120+
"//executorch/runtime/core/exec_aten/util:tensor_shape_to_c_string",
120121
"//executorch/runtime/kernel:kernel_includes",
121122
],
122123
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],

0 commit comments

Comments
 (0)