Commit 439b762

[ExecuTorch] Add broadcasting support to optimized op_div
Summary: Similar to broadcast support in op_mul

Test Plan: Tests added

ghstack-source-id: 3466308
Pull Request resolved: #8257
1 parent 3ec7524 commit 439b762

3 files changed: +81 lines, -49 lines

kernels/optimized/cpu/op_div.cpp

Lines changed: 27 additions & 39 deletions
@@ -120,48 +120,36 @@ Tensor& opt_div_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
+    // Reason for using alpha is becasuse handle_broadcast_elementwise
+    // is used for add and sub as well:
     if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-      lhs = &b;
-      rhs = &a;
+            ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+        selected_optimized_path ==
+            ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+        selected_optimized_path ==
+            ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+      // This behavior is a bit confusing.
+      // Reason we swap out args here is because handle_broadcast_elementwise
+      // handles this selected_optimized_path option a bit differently.
+      // This should really be resoled in handle_broadcast_elementwise.
+      // However, the current blocker is that handle_broadcast_elementwise tries
+      // to be agnostic of op. This should be fixed, likely by moving lambda
+      // creation to handle_broadcast_elementwise and it be aware of which op is
+      // being executed.
+      auto div_lambda = [](auto x, auto y, auto alpha) {
+        (void)alpha;
+        return y / x;
+      };
+      return torch::executor::handle_broadcast_elementwise(
+          ctx, div_lambda, a, b, out, selected_optimized_path);
     } else {
-      // Catch failure to update logic when subing new broadcasting possibility.
-      ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
-      lhs = &a;
-      rhs = &b;
+      auto div_lambda = [](auto x, auto y, auto alpha) {
+        (void)alpha;
+        return x / y;
+      };
+      return torch::executor::handle_broadcast_elementwise(
+          ctx, div_lambda, a, b, out, selected_optimized_path);
     }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      if (selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [](Vec x, Vec y) { return y / x; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
-      } else {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [](Vec x, Vec y) { return x / y; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
-      }
-    });
   } else {
     ScalarType common_type = get_compute_type(a_type, b_type);
     ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

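For readers scanning the hunk above: the optimized broadcast handling for div now funnels through torch::executor::handle_broadcast_elementwise with a division lambda, and for the *ReverseArguments paths the lambda returns y / x instead of x / y, compensating for the operand order the helper uses internally (as the committed comments explain). Below is a minimal, self-contained sketch of that idea; broadcast_map and its flattened-vector arguments are hypothetical stand-ins for illustration, not the ExecuTorch API.

// Minimal sketch, not the ExecuTorch API: broadcast_map is a hypothetical
// stand-in for an op-agnostic helper in the spirit of
// handle_broadcast_elementwise, restricted to a flattened {rows, cols}
// tensor broadcast against a {cols} tensor.
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// The helper always passes the full tensor's element as the first lambda
// argument and the broadcast (smaller) tensor's element as the second.
std::vector<float> broadcast_map(
    const std::vector<float>& full,  // flattened {rows, cols}
    const std::vector<float>& small, // {cols}
    const std::function<float(float, float)>& op) {
  std::vector<float> out(full.size());
  for (std::size_t i = 0; i < full.size(); ++i) {
    out[i] = op(full[i], small[i % small.size()]);
  }
  return out;
}

int main() {
  std::vector<float> a = {1, 2, 3, 4, 5, 6}; // shape {2, 3}
  std::vector<float> b = {2, 4, 8};          // shape {3}, broadcast over rows of a

  // a / b: operand order already matches, so the lambda divides x by y.
  auto a_div_b = broadcast_map(a, b, [](float x, float y) { return x / y; });

  // b / a (the "ReverseArguments" situation): the helper still receives the
  // full tensor first, so the lambda compensates by returning y / x.
  auto b_div_a = broadcast_map(a, b, [](float x, float y) { return y / x; });

  for (float v : a_div_b) std::cout << v << ' '; // 0.5 0.5 0.375 2 1.25 0.75
  std::cout << '\n';
  for (float v : b_div_a) std::cout << v << ' '; // 2 2 2.66667 0.5 0.8 1.33333
  std::cout << '\n';
  return 0;
}

The same sharing explains the unused alpha parameter in the committed lambdas: handle_broadcast_elementwise also serves add and sub, which take a scaling argument, so the div lambda accepts alpha and ignores it.
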
kernels/test/op_div_test.cpp

Lines changed: 54 additions & 0 deletions
@@ -83,6 +83,52 @@ class OpDivOutTest : public OperatorTest {
     ET_EXPECT_KERNEL_FAILURE(context_, op_div_out(a, b, out));
   }
 
+  template <ScalarType DTYPE>
+  void test_broadcast_3D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
+
+    // Destination for output of mul.
+    Tensor out =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor expected = tf_a.make(
+        {2, 2, 3},
+        /*data=*/
+        {0.5000,
+         0.6667,
+         0.75002,
+         2.0000,
+         1.6667,
+         1.5000,
+         1.4000,
+         1.3333,
+         1.2857,
+         2.0000,
+         1.8333,
+         1.7143});
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE_WITH_TOL(op_div_out(a, b, out), expected, 1e-4, 1e-4);
+    expected = tf_a.make(
+        {2, 2, 3},
+        /*data=*/
+        {2.0000,
+         1.5000,
+         1.3333,
+         0.5000,
+         0.6000,
+         0.6667,
+         0.7143,
+         0.7500,
+         0.7778,
+         0.5000,
+         0.5455,
+         0.5833});
+    EXPECT_TENSOR_CLOSE_WITH_TOL(op_div_out(b, a, out), expected, 1e-4, 1e-4);
+  }
+
   /**
    * Common testing for div operator, for float output types
    */
@@ -457,6 +503,14 @@ TEST_F(OpDivOutTest, DynamicShapeUpperBoundLargerThanExpected) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
+TEST_F(OpDivOutTest, BroadcastNDTest) {
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Float>();
+  // half and bfloat16 are not supported for div quite yet
+  // test_broadcast_3D<ScalarType::Half>();
+  // test_broadcast_3D<ScalarType::BFloat16>();
+}
+
 TEST_F(OpDivOutTest, DynamicShapeUnbound) {
   GTEST_SKIP() << "Dynamic shape not supported";
   TensorFactory<ScalarType::Float> tf;

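A quick sanity check on the new expected data: b has shape {2, 1, 3}, so each {1, 3} slice of b is broadcast across both rows of the matching slice of a. The standalone sketch below (illustration only, not part of the test suite) reproduces the first six expected values of op_div_out(a, b, out).

// Standalone sketch: reproduces the first six expected values in
// test_broadcast_3D. b has shape {2, 1, 3}, so b[0] = {2, 3, 4} is reused
// for both rows of a[0].
#include <cstdio>

int main() {
  float a0[2][3] = {{1, 2, 3}, {4, 5, 6}}; // a[0] from the test
  float b0[3] = {2, 3, 4};                 // b[0] from the test
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 3; ++c) {
      // Prints 0.5000 0.6667 0.7500 2.0000 1.6667 1.5000
      std::printf("%.4f ", a0[r][c] / b0[c]);
    }
  }
  std::printf("\n");
  return 0;
}

The second expected block, used for op_div_out(b, a, out), is just the elementwise reciprocal of the first (2.0000 = 1/0.5000, 1.5000 = 1/0.6667, and so on).
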
kernels/test/op_mul_test.cpp

Lines changed: 0 additions & 10 deletions
@@ -417,16 +417,6 @@ TEST_F(OpMulOutTest, BroadcastA2BTest) {
   test_broadcast_a2b<ScalarType::Int>();
   test_broadcast_a2b<ScalarType::Half>();
   test_broadcast_a2b<ScalarType::BFloat16>();
-
-  // Test 3D tensors
-  test_broadcast_3D<ScalarType::Float>();
-  test_broadcast_3D<ScalarType::Half>();
-  test_broadcast_3D<ScalarType::BFloat16>();
-
-  // Test 4D tensors
-  test_broadcast_4D<ScalarType::Float>();
-  test_broadcast_4D<ScalarType::Half>();
-  test_broadcast_4D<ScalarType::BFloat16>();
 }
 
 // Broadcast tensor a's size to tensor b's size
