Skip to content

Commit b9c056e

Browse files
committed
Use tensor_shape_to_c_string for error in check_mask_indices
Rolling out for #7902 ghstack-source-id: 7f375bd ghstack-comment-id: 2643854240 Pull Request resolved: #8314
1 parent 883d33a commit b9c056e

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

kernels/portable/cpu/util/advanced_index_util.cpp

+17-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
10+
#include <executorch/runtime/core/exec_aten/util/tensor_shape_to_c_string.h>
1011
#include <executorch/runtime/kernel/kernel_includes.h>
1112

1213
namespace torch {
@@ -49,9 +50,22 @@ bool check_mask_indices(const Tensor& in, TensorOptList indices) {
4950
ET_LOG_MSG_AND_RETURN_IF_FALSE(
5051
index.dim() > 0, "Zero-dimensional mask index not allowed");
5152
for (auto j = 0; j < index.dim(); j++) {
52-
ET_LOG_MSG_AND_RETURN_IF_FALSE(
53-
index.size(j) == in.size(in_i + j),
54-
"The shape of mask index must match the sizes of the corresponding input dimensions.");
53+
if (index.size(j) != in.size(in_i + j)) {
54+
#ifdef ET_LOG_ENABLED
55+
auto mask_shape = executorch::runtime::tensor_shape_to_c_string(
56+
executorch::runtime::Span<const Tensor::SizesType>(
57+
index.sizes().data(), index.sizes().size()));
58+
auto input_shape = executorch::runtime::tensor_shape_to_c_string(
59+
executorch::runtime::Span<const Tensor::SizesType>(
60+
in.sizes().data() + in_i, index.sizes().size()));
61+
#endif // ET_LOG_ENABLED
62+
ET_LOG(
63+
Error,
64+
"The shape of mask index %s must match the sizes of the corresponding input dimensions %s.",
65+
mask_shape.data(),
66+
input_shape.data());
67+
return false;
68+
}
5569
}
5670
in_i += index.dim();
5771
} else {

kernels/portable/cpu/util/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def define_common_targets():
117117
compiler_flags = ["-Wno-missing-prototypes"],
118118
deps = [
119119
":broadcast_util",
120+
"//executorch/runtime/core/exec_aten/util:tensor_shape_to_c_string",
120121
"//executorch/runtime/kernel:kernel_includes",
121122
],
122123
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],

0 commit comments

Comments
 (0)