[ExecuTorch] Add broadcast support for optimized add op #8205
Conversation
Summary: Refactoring broadcast handling utils that were added for op_mul. This is in preparation to use these utils to handle broadcast for other ops such as add, sub, and div. Plus remove a redundant test. Test Plan: optimized_kernels_test in CI Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8205
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure) As of commit 6f2f01a with merge base 8148603: BROKEN TRUNK - the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
kernels/optimized/cpu/binary_ops.h
Outdated
```diff
 Tensor& handle_last_dim_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
     const Tensor& a,
     const Tensor& b,
     Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
+    const ElementwiseOptimizedPath selected_optimized_path,
+    executorch::aten::optional<Scalar>& alpha = {}) {
```
The error messages are telling you this needs to be a const ref. Also, why is this not std::optional?
Yeah, just realized that. Not sure why it did not throw an error in my local build; different compile options, I guess.
I just followed what I see elsewhere. Happy to switch to std::optional too, which is what I guess is backing that, but maybe for the ATen build it aliases c10::optional? Let me check that first.
c10::optional is gone
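Putting the two comments together, a minimal sketch of the signature being asked for; the parameter list is taken from the diff above, while the `template` line and the const-ref alpha are assumptions for illustration:

```cpp
// Sketch only: same parameters as the diff above, but alpha taken by const
// reference (std::optional would be the modern spelling, since
// c10::optional is gone).
template <typename Op>
Tensor& handle_last_dim_broadcast_elementwise(
    KernelRuntimeContext& ctx,
    const Op& vec_fun,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const ElementwiseOptimizedPath selected_optimized_path,
    const executorch::aten::optional<Scalar>& alpha = {});
```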
kernels/optimized/cpu/binary_ops.h
Outdated
```cpp
CTYPE alpha_val;
Vec alpha_val_vec(alpha_val);
```
Normally I would say "alpha_val needs to be initialized; C++ doesn't have default zero-initialization for primitives", but actually the problem here is that `alpha_val` needs to move under the `if`:

```cpp
Vec alpha_val_vec;
if (alpha.has_value()) {
  CTYPE alpha_val;
  ET_KERNEL_CHECK(...)
```
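For reference, a sketch of the complete pattern being suggested; the extract_scalar call and check arguments are copied from the updated code further down this thread, so treat this as illustrative rather than the final diff:

```cpp
Vec alpha_val_vec;
if (alpha.has_value()) {
  CTYPE alpha_val;
  // Extract the Scalar into a CTYPE value; bail out on failure.
  ET_KERNEL_CHECK(
      ctx,
      native::utils::extract_scalar(alpha.value(), &alpha_val),
      InvalidArgument, );
  // Only broadcast alpha into a vector once we know it is valid.
  alpha_val_vec = Vec(alpha_val);
}
```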
kernels/optimized/cpu/binary_ops.h
Outdated
```diff
 Tensor& handle_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
     const Tensor& a,
     const Tensor& b,
     Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
+    const ElementwiseOptimizedPath selected_optimized_path,
+    executorch::aten::optional<Scalar> alpha = {}) {
```
why is this by-value but the other one is a reference? make consistent
oh good call out. my bad
kernels/optimized/cpu/binary_ops.h
Outdated
```cpp
    inner_size);
ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
  using Vec = executorch::vec::Vectorized<CTYPE>;
  CTYPE alpha_val;
```
same problem as above
kernels/optimized/cpu/op_add.cpp
Outdated
```cpp
// This behavior is a bit confusing.
// Reason we swap out args here is because handle_broadcast_elementwise
// handles this selected_optimized_path option a bit differently.
// This should really be resoled in handle_broadcast_elementwise.
```
s/resoled/resolved/
kernels/optimized/cpu/op_add.cpp
Outdated
```cpp
        ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
    selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
  // This behavior is a bit confusing.
```
I don't understand what's confusing here; there is an argument that should be scaled by alpha_val, and we have to scale the right one. I definitely don't think handle_broadcast_elementwise should be coupled to the specific op.
The problem is this: all the reverse-argument stuff has specifically different handling inside `handle_broadcast_elementwise`. But for that handling inside `handle_broadcast_elementwise` to work, this change is necessary, which makes them coupled and fragile.
I guess confusing is not the right word here though.
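To make the coupling concrete, here is a hedged illustration (hypothetical lambdas, not the exact code from this PR) of why the reverse-argument paths force the caller and `handle_broadcast_elementwise` to agree on which operand ends up where:

```cpp
// add.out computes out = a + alpha * b, so alpha must always scale b.
auto add_lambda = [](auto x, auto y, auto alpha) { return x + alpha * y; };

// If a *ReverseArguments path swaps the inputs before calling
// handle_broadcast_elementwise, the lambda has to scale the other operand
// instead; which form is correct depends on how the caller did the swap.
auto add_lambda_swapped = [](auto x, auto y, auto alpha) { return y + alpha * x; };
```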
kernels/optimized/cpu/op_mul.cpp
Outdated
```diff
@@ -130,8 +130,12 @@ Tensor& opt_mul_out(
         out.numel());
       });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     auto mul_lambda = [](auto x, auto y) { return x * y; };
     return torch::executor::handle_broadcast_elementwise(
     // Reason for using alpha:
```
missing rest of comment after the colon
kernels/optimized/cpu/op_mul.cpp
Outdated
```cpp
return torch::executor::handle_broadcast_elementwise(
// Reason for using alpha:
auto mul_lambda = [](auto x, auto y, auto alpha) {
  (void)alpha;
```
thank you :)
kernels/optimized/cpu/binary_ops.h
Outdated
```cpp
template <BinaryOpType op_type>
struct BinaryOpTypeName;

template <>
struct BinaryOpTypeName<BinaryOpType::kAdd> {
  static constexpr char kName[] = "add.out";
};

template <>
struct BinaryOpTypeName<BinaryOpType::kSub> {
  static constexpr char kName[] = "sub.out";
};

template <>
struct BinaryOpTypeName<BinaryOpType::kMul> {
  static constexpr char kName[] = "mul.out";
};

template <>
struct BinaryOpTypeName<BinaryOpType::kDiv> {
  static constexpr char kName[] = "div.out";
};
```
you don't need to do this. see existing example:
executorch/kernels/portable/cpu/op_rsub.cpp, lines 50 to 55 in c82a7df:

```cpp
static constexpr const char op_name[] = "rsub.Scalar_out";

ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
  const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
  const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
  utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
```
The secret sauce is that the string literal has to be a `static constexpr const char []`, and then you can pass it to a `const char*` template argument directly.
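As a standalone illustration of that point (not ExecuTorch code; the names here are made up), a string literal stored in a `static constexpr const char[]` has static storage duration, so its address is a valid `const char*` non-type template argument:

```cpp
#include <cstdio>

template <const char* kOpName>
void log_op() {
  std::printf("running %s\n", kOpName);
}

// Static storage duration: the array's address can be a template argument.
static constexpr const char kAddOut[] = "add.out";

int main() {
  log_op<kAddOut>();       // OK: the array decays to a usable const char*
  // log_op<"add.out">();  // ill-formed: a bare string literal cannot be
  //                       // used as a const char* template argument
  return 0;
}
```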
Thanks. I was hoping you would point me to something better for this
…imized add op" Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
…imized add op" Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
just re-request review if my suggestion is bad :)
kernels/optimized/cpu/op_add.cpp
Outdated
```cpp
// creation to handle_broadcast_elementwise and it be aware of which op is
// being executed.
```
not sure I agree, but we can settle that on a review of a proposed change
kernels/optimized/cpu/binary_ops.h
Outdated
```cpp
using Vec = executorch::vec::Vectorized<CTYPE>;
Vec alpha_val_vec;
if (alpha.has_value()) {
  CTYPE alpha_val;
  ET_KERNEL_CHECK(
      ctx,
      native::utils::extract_scalar(alpha.value(), &alpha_val),
      InvalidArgument, );
  alpha_val_vec = Vec(alpha_val);
}
auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
  return vec_fun(a, b, alpha_val_vec);
};
```
Oh, I see, you're having problems with the lambda because of this part. You can solve this by factoring the code differently. The end result at the callsite could look something like:
```cpp
// broadcast_op_plan is a struct containing all the stuff you work out that
// isn't dependent on the dtype, like lhs/rhs. It does ET_KERNEL_CHECKs
// internally and returns nullopt if they fail.
auto broadcast_op_plan_opt = plan_broadcast_elementwise(...);
if (!broadcast_op_plan_opt) {
  // a check already failed
  return;
}
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
  // wrap up the bit that extracts alpha into a Vec
  auto alpha_val_vec_opt = extract_scalar_to_vector<CTYPE>();
  if (!alpha_val_vec_opt) {
    // awkward that this only returns from the lambda, but this is a
    // generic ET_KERNEL_CHECK problem
    return;
  }
  auto add_lambda = [alpha_val_vec = *alpha_val_vec_opt](auto x, auto y) {
    return y + alpha_val_vec * x;
  };
  execute_broadcast_elementwise_plan<CTYPE>(
      *broadcast_op_plan_opt, add_lambda, ...);
});
```
disclaimer: this is off the top of my head and it may be possible to unify some of this stuff with dtype_util.h for further simplification, though dtype_util is mostly intended to cut size/build time of portable ops
Good point. Let me see if I run into other issues enabling such a refactor.
Ok, so I looked at the refactor required. I think it is doable at the cost of moving the `ET_SWITCH_REALB_TYPES` macros to the callsite in the respective ops. The downside is that now, if you enable a new dtype for the optimized path, you have to change all the callsites.
So I am not fully convinced that it is better to go down that route, but I want to see what your reasoning is.
> you have to change all the callsites.

That's just a matter of typing, right? If you plan to do it (I suppose optimizing Half/BFloat16 should be on our TODO list if the hardware supports the relevant instructions) and you really don't want to change 4-5 files later (you'll have to change them anyway for specifically Half/BFloat16 because there are opt-outs), you could always `#define ET_SWITCH_OPTIMIZED_ELEMENTWISE_BROADCAST_OP_TYPES ET_SWITCH_REALB_TYPES` pre-emptively.
Ok, that's fair. But is your reasoning for this change simpler code, or do you see a perf impact?
I am not too stuck on it, so I will just go ahead and do it, but I wanted to understand your reasoning.
Simpler, less repetitive code.
Ok, will make the change, but this will likely marginally increase size, since now the whole `handle_broadcast_elementwise` function is dtype-specialized.
…imized add op" Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
…imized add op" Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
```cpp
Tensor a = tf_a.make(
    {2, 2, 3, 5},
    /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
              16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
              31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
              46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
Tensor b = tf_a.make(
    {2, 1, 3, 5},
    /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
              16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
```
nit: it would probably be more reviewable to fill these programmatically, such as with std::iota, but certainly not blocking
```cpp
/*data=*/{2,  4,  6,  8,  10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
          17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45,
          47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75,
          62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90});
```
ditto programmatic fill
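A sketch of the programmatic fill being suggested; the `tf_a.make` overload taking a `std::vector` for the data, and the use of `int32_t`, are assumptions for illustration (the test's actual ctype may differ):

```cpp
#include <numeric>
#include <vector>

// a: shape {2, 2, 3, 5}, values 1..60; b: shape {2, 1, 3, 5}, values 1..30.
std::vector<int32_t> a_data(60);
std::iota(a_data.begin(), a_data.end(), 1);
Tensor a = tf_a.make({2, 2, 3, 5}, a_data);

std::vector<int32_t> b_data(30);
std::iota(b_data.begin(), b_data.end(), 1);
Tensor b = tf_a.make({2, 1, 3, 5}, b_data);
```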
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
…imized add op" Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
…imized add op" Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…imized add op" Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…imized add op" Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
…imized add op" Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
…imized add op" Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491814](https://our.internmc.facebook.com/intern/diff/D69491814) [ghstack-poisoned]
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Stack from ghstack (oldest at bottom):
Summary:
This brings the add op to feature parity, w.r.t. broadcasting, with the mul op in the optimized kernels lib.
Test Plan:
tests added
Reviewers:
Subscribers:
Tasks:
Tags:
cc @larryliu0820 @manuelcandales
Differential Revision: D69491814