Skip to content

Commit 115da12

Browse files
committed
Include tensor shapes in get_broadcast_target_size error message
This is the motivating example for #7902. Test Plan: Injected failure to new broadcast_test and saw shapes in error message. ghstack-source-id: 6ae4789fa101cbcdedd316490e83a951460479bd ghstack-comment-id: 2613189501 Pull Request resolved: #7944
1 parent 0de8e76 commit 115da12

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

kernels/portable/cpu/util/broadcast_util.cpp

+14-4
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,20 @@ ET_NODISCARD Error get_broadcast_target_size(
213213
Tensor::SizesType* out_sizes,
214214
const size_t out_sizes_len,
215215
size_t* out_dim) {
216-
ET_CHECK_OR_RETURN_ERROR(
217-
tensors_are_broadcastable_between(a_size, b_size),
218-
InvalidArgument,
219-
"Two input tensors should be broadcastable.\n");
216+
if (!tensors_are_broadcastable_between(a_size, b_size)) {
217+
const auto a_shape_str = tensor_shape_to_c_string(
218+
executorch::runtime::Span<const Tensor::SizesType>(
219+
a_size.data(), a_size.size()));
220+
const auto b_shape_str = tensor_shape_to_c_string(
221+
executorch::runtime::Span<const Tensor::SizesType>(
222+
b_size.data(), b_size.size()));
223+
ET_LOG(
224+
Error,
225+
"Two input tensors should be broadcastable but got shapes %s and %s.",
226+
a_shape_str.data(),
227+
b_shape_str.data());
228+
return executorch::runtime::Error::InvalidArgument;
229+
}
220230

221231
auto a_dim = a_size.size();
222232
auto b_dim = b_size.size();

kernels/portable/cpu/util/test/broadcast_test.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,15 @@ TEST(BroadcastUtilTest, GetBroadcastTargetSize) {
129129
EXPECT_TRUE(
130130
ArrayRef<Tensor::SizesType>(expected_output_size, expected_output_dim)
131131
.equals(ArrayRef<Tensor::SizesType>({5, 2, 2})));
132+
133+
Tensor c = tf.zeros({4, 5});
134+
err = get_broadcast_target_size(
135+
a,
136+
c,
137+
expected_output_size,
138+
torch::executor::kTensorDimensionLimit,
139+
&expected_output_dim);
140+
EXPECT_EQ(err, torch::executor::Error::InvalidArgument);
132141
}
133142

134143
size_t linearize_indexes(size_t* indexes, size_t indexes_len, const Tensor& t) {

0 commit comments

Comments
 (0)