Commit cce5274
[ExecuTorch] Add broadcasting support to optimized op_div
Summary: Similar to broadcast support in op_mul.
Test Plan: Tests added.
ghstack-source-id: bedf1ef
Pull Request resolved: #8257
1 parent 0d1596b commit cce5274
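For context, broadcasting lets an elementwise op combine tensors whose shapes differ only in dimensions of size 1, by logically repeating the smaller operand. The sketch below is a minimal, self-contained illustration of the 2d-by-1d case in plain C++; div_2d_by_1d is an illustrative name, not an ExecuTorch API:

    #include <cstdio>

    // Broadcast-divide a [rows x cols] tensor by a [cols] tensor: the 1-d
    // operand is logically repeated across every row of the 2-d operand.
    void div_2d_by_1d(
        float* out, const float* a, const float* b, int rows, int cols) {
      for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
          out[r * cols + c] = a[r * cols + c] / b[c];
        }
      }
    }

    int main() {
      const float a[6] = {1, 2, 3, 4, 5, 6}; // shape [2, 3]
      const float b[3] = {2, 4, 8};          // shape [3], broadcast across rows
      float out[6];
      div_2d_by_1d(out, a, b, 2, 3);
      for (float v : out) {
        std::printf("%g ", v); // prints: 0.5 0.5 0.375 2 1.25 0.75
      }
      std::printf("\n");
      return 0;
    }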

File tree

2 files changed: +69, -38 lines changed

kernels/optimized/cpu/op_div.cpp

Lines changed: 15 additions & 38 deletions
@@ -120,46 +120,23 @@ Tensor& opt_div_out(
         out.numel());
       });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-      lhs = &b;
-      rhs = &a;
-    } else {
-      // Catch failure to update logic when subing new broadcasting possibility.
-      ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
-      lhs = &a;
-      rhs = &b;
-    }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
+    // Reason for using alpha is because handle_broadcast_elementwise
+    // is used for add and sub as well:
+    static constexpr const char op_name[] = "div.out";
+    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       if (selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [](Vec x, Vec y) { return y / x; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
+              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+          selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+          selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+        auto div_lambda = [](auto x, auto y) { return y / x; };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, div_lambda, a, b, out, selected_optimized_path);
       } else {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [](Vec x, Vec y) { return x / y; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
+        auto div_lambda = [](auto x, auto y) { return x / y; };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, div_lambda, a, b, out, selected_optimized_path);
       }
     });
   } else {
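Note the operand order in the lambdas above. Division is not commutative, so on the ReverseArguments paths, where the broadcast helper processes the operands in swapped order, the lambda is flipped to y / x to keep the result equal to a / b; the old code achieved the same effect by swapping lhs and rhs. A minimal sketch of the idea, using an illustrative apply helper rather than the real handle_broadcast_elementwise signature:

    #include <cassert>

    // Illustrative stand-in for a broadcast helper that always receives the
    // broadcast-shaped operand second; callers on the "reverse" path pass the
    // operands swapped and compensate by flipping the lambda.
    template <typename Op>
    float apply(Op op, float first, float second) {
      return op(first, second);
    }

    int main() {
      float a = 10.f, b = 2.f; // we always want out == a / b == 5
      // Normal path: operands in order, lambda is x / y.
      assert(apply([](float x, float y) { return x / y; }, a, b) == 5.f);
      // Reverse-arguments path: operands swapped, so the lambda becomes y / x.
      assert(apply([](float x, float y) { return y / x; }, b, a) == 5.f);
      return 0;
    }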

kernels/test/op_div_test.cpp

Lines changed: 54 additions & 0 deletions
@@ -83,6 +83,52 @@ class OpDivOutTest : public OperatorTest {
     ET_EXPECT_KERNEL_FAILURE(context_, op_div_out(a, b, out));
   }
 
+  template <ScalarType DTYPE>
+  void test_broadcast_3D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
+
+    // Destination for output of div.
+    Tensor out =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor expected = tf_a.make(
+        {2, 2, 3},
+        /*data=*/
+        {0.5000,
+         0.6667,
+         0.7500,
+         2.0000,
+         1.6667,
+         1.5000,
+         1.4000,
+         1.3333,
+         1.2857,
+         2.0000,
+         1.8333,
+         1.7143});
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE_WITH_TOL(op_div_out(a, b, out), expected, 1e-4, 1e-4);
+    expected = tf_a.make(
+        {2, 2, 3},
+        /*data=*/
+        {2.0000,
+         1.5000,
+         1.3333,
+         0.5000,
+         0.6000,
+         0.6667,
+         0.7143,
+         0.7500,
+         0.7778,
+         0.5000,
+         0.5455,
+         0.5833});
+    EXPECT_TENSOR_CLOSE_WITH_TOL(op_div_out(b, a, out), expected, 1e-4, 1e-4);
+  }
+
   /**
    * Common testing for div operator, for float output types
    */
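The expected values are plain quotients with b's singleton middle dimension repeated: in the first batch, {1, 2, 3} / {2, 3, 4} = {0.5000, 0.6667, 0.7500} and {4, 5, 6} / {2, 3, 4} = {2.0000, 1.6667, 1.5000}; in the second batch, both rows of a are divided by the slice {5, 6, 7}, giving {1.4000, 1.3333, 1.2857} and {2.0000, 1.8333, 1.7143}. The second EXPECT_TENSOR_CLOSE_WITH_TOL checks the reversed call op_div_out(b, a, out), i.e. the inverted quotients, starting with {2, 3, 4} / {1, 2, 3} = {2.0000, 1.5000, 1.3333}.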
@@ -457,6 +503,14 @@ TEST_F(OpDivOutTest, DynamicShapeUpperBoundLargerThanExpected) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
+TEST_F(OpDivOutTest, BroadcastNDTest) {
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Float>();
+  // half and bfloat16 are not supported for div quite yet
+  // test_broadcast_3D<ScalarType::Half>();
+  // test_broadcast_3D<ScalarType::BFloat16>();
+}
+
 TEST_F(OpDivOutTest, DynamicShapeUnbound) {
   GTEST_SKIP() << "Dynamic shape not supported";
   TensorFactory<ScalarType::Float> tf;
