[ExecuTorch] Add broadcast support for optimized add op #8205


Merged

26 commits merged into main from gh/kimishpatel/154/head on Feb 15, 2025

Conversation

kimishpatel
Contributor

@kimishpatel commented Feb 5, 2025

Stack from ghstack (oldest at bottom):

Summary:
This brings the add op to feature parity with the mul op in the optimized
kernels lib with respect to broadcasting.

Test Plan:
tests added

cc @larryliu0820 @manuelcandales

Differential Revision: D69491814

Summary:
Refactors the broadcast handling utils that were added for op_mul, in
preparation for using these utils to handle broadcasting for other ops such
as add, sub, and div.
Also removes a redundant test.

Test Plan:
optimized_kernels_test in CI

pytorch-bot bot commented Feb 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8205

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6f2f01a with merge base 8148603:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kimishpatel added a commit that referenced this pull request Feb 5, 2025
ghstack-source-id: e4dea30
Pull Request resolved: #8205
@facebook-github-bot added the CLA Signed label (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.) on Feb 5, 2025
@kimishpatel requested a review from swolchok February 6, 2025 06:40
 Tensor& handle_last_dim_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
     const Tensor& a,
     const Tensor& b,
     Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
+    const ElementwiseOptimizedPath selected_optimized_path,
+    executorch::aten::optional<Scalar>& alpha = {}) {

Contributor

The error messages are telling you this needs to be a const ref. Also, why is this not std::optional?

Contributor Author

Yeah, just realized that. Not sure why it did not throw an error in my local build; different compile options, I guess.

I just followed what I saw elsewhere. Happy to switch to std::optional too, which I guess is what is backing it, but maybe for the ATen build it aliases c10::optional? Let me check that first.

Contributor

c10::optional is gone
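
To make the fix concrete, here is a minimal standalone sketch (with a hypothetical Scalar stand-in, not the ExecuTorch type) of why the parameter wants to be a const reference: the braced default creates a temporary, and only a const reference can bind to it.

#include <optional>

// Hypothetical stand-in for the real Scalar type; illustration only.
struct Scalar {
  double v = 1.0;
};

// A const reference parameter can take a braced default: the empty optional
// temporary lives for the duration of the call. A non-const reference cannot
// bind to that temporary, which is what the compile errors were pointing out.
double scale_or_one(const std::optional<Scalar>& alpha = {}) {
  return alpha.has_value() ? alpha->v : 1.0;
}

int main() {
  double a = scale_or_one();             // default: no alpha, scale by 1
  double b = scale_or_one(Scalar{2.0});  // alpha provided
  return (a == 1.0 && b == 2.0) ? 0 : 1;
}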

Comment on lines 256 to 257
CTYPE alpha_val;
Vec alpha_val_vec(alpha_val);

Contributor

normally I would say "alpha_val needs to be initialized; C++ doesn't have default zero-initialization for primitives", but actually the problem here is that alpha_val needs to move under the if:

Vec alpha_val_vec;
if (alpha.has_value()) {
  CTYPE alpha_val;
  ET_KERNEL_CHECK(...)

 Tensor& handle_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
     const Tensor& a,
     const Tensor& b,
     Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
+    const ElementwiseOptimizedPath selected_optimized_path,
+    executorch::aten::optional<Scalar> alpha = {}) {

Contributor

Why is this by-value but the other one is a reference? Make them consistent.

Contributor Author

Oh, good callout. My bad.

inner_size);
ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
CTYPE alpha_val;

Contributor

same problem as above

// This behavior is a bit confusing.
// Reason we swap out args here is because handle_broadcast_elementwise
// handles this selected_optimized_path option a bit differently.
// This should really be resoled in handle_broadcast_elementwise.

Contributor

s/resoled/resolved/

ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
// This behavior is a bit confusing.

Contributor

I don't understand what's confusing here; there is an argument that should be scaled by alpha_val, and we have to scale the right one. I definitely don't think handle_broadcast_elementwise should be coupled to the specific op.

Contributor Author

The problem is this: all the reverse-argument cases get specifically different handling inside handle_broadcast_elementwise, and for that handling to work, this change is necessary, which makes them coupled and fragile.

Contributor Author

I guess confusing is not the right word here though.
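
To make the coupling concrete, here is a small standalone sketch (plain floats rather than Vectorized, hypothetical names): when the kernel swaps a and b to reuse the non-reversed broadcast path, the add lambda has to scale its first argument by alpha instead of its second to keep computing a + alpha * b.

#include <cassert>

int main() {
  // add.out computes out = a + alpha * b.
  const float a = 3.0f, b = 5.0f, alpha = 2.0f;

  // Lambda used when the arguments arrive in (a, b) order.
  auto add_lambda = [alpha](float x, float y) { return x + alpha * y; };

  // Lambda needed when the kernel swaps the operands (the *ReverseArguments
  // paths): it receives (b, a), so alpha must scale the first argument to
  // still produce a + alpha * b.
  auto add_lambda_reversed = [alpha](float x, float y) { return y + alpha * x; };

  assert(add_lambda(a, b) == add_lambda_reversed(b, a));  // both equal 13
  return 0;
}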

@@ -130,8 +130,12 @@ Tensor& opt_mul_out(
out.numel());
});
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
auto mul_lambda = [](auto x, auto y) { return x * y; };
return torch::executor::handle_broadcast_elementwise(
// Reason for using alpha:

Contributor

missing rest of comment after the colon

return torch::executor::handle_broadcast_elementwise(
// Reason for using alpha:
auto mul_lambda = [](auto x, auto y, auto alpha) {
(void)alpha;

Contributor Author

thank you :)

Comment on lines 61 to 83
template <BinaryOpType op_type>
struct BinaryOpTypeName;

template <>
struct BinaryOpTypeName<BinaryOpType::kAdd> {
  static constexpr char kName[] = "add.out";
};

template <>
struct BinaryOpTypeName<BinaryOpType::kSub> {
  static constexpr char kName[] = "sub.out";
};

template <>
struct BinaryOpTypeName<BinaryOpType::kMul> {
  static constexpr char kName[] = "mul.out";
};

template <>
struct BinaryOpTypeName<BinaryOpType::kDiv> {
  static constexpr char kName[] = "div.out";
};


Contributor

You don't need to do this. See the existing example:

static constexpr const char op_name[] = "rsub.Scalar_out";
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
  const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
  const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
  utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(

The secret sauce is that the string literal has to be a static constexpr const char[], and then you can pass it to a const char* template argument directly.
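
A minimal standalone illustration of that mechanism, with made-up names rather than the real ExecuTorch macros: an array with static storage duration can be passed directly as a const char* non-type template argument.

#include <cstdio>

// The template parameter is a plain const char*.
template <const char* kName>
void print_op_name() {
  std::printf("%s\n", kName);
}

// Because these arrays have static storage duration, their addresses are
// valid non-type template arguments.
static constexpr const char kAddName[] = "add.out";
static constexpr const char kMulName[] = "mul.out";

int main() {
  print_op_name<kAddName>();  // prints "add.out"
  print_op_name<kMulName>();  // prints "mul.out"
  return 0;
}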

Contributor Author

Thanks. I was hoping you would point me to something better for this

@kimishpatel added the module: kernels (Issues related to kernel libraries and utilities, and code under kernels/) and release notes: ops & kernels (Changes to the opset and any new / changed kernel implementations) labels on Feb 7, 2025

Contributor

@swolchok left a comment

just re-request review if my suggestion is bad :)

Comment on lines 155 to 156
// creation to handle_broadcast_elementwise and it be aware of which op is
// being executed.

Contributor

not sure I agree, but we can settle that on a review of a proposed change

Comment on lines 225 to 237
using Vec = executorch::vec::Vectorized<CTYPE>;
Vec alpha_val_vec;
if (alpha.has_value()) {
  CTYPE alpha_val;
  ET_KERNEL_CHECK(
      ctx,
      native::utils::extract_scalar(alpha.value(), &alpha_val),
      InvalidArgument, );
  alpha_val_vec = Vec(alpha_val);
}
auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
  return vec_fun(a, b, alpha_val_vec);
};

Contributor

Oh, I see, you're having problems with the lambda because of this part. You can solve this by factoring the code differently.

The end result at the call site could look something like:

auto broadcast_op_plan_opt = plan_broadcast_elementwise(...); // broadcast_op_plan is a struct containing all the stuff you work out that isn't dependent on the dtype, like lhs, rhs. It does ET_KERNEL_CHECKs internally and returns nullopt if they fail.
if (!broadcast_op_plan_opt) {
  // a check already failed
  return;
}
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
  auto alpha_val_vec_opt = extract_scalar_to_vector<CTYPE>(); // wrap up the bit that 
  if (!alpha_val_vec_opt) {
    // awkward that this only returns from the lambda, but this is a generic ET_KERNEL_CHECK problem
    return;
  }
  auto add_lambda = [alpha_val_vec = *alpha_val_vec_opt](auto x, auto y) {
    return y + alpha_val_vec * x;
  };
  execute_broadcast_elementwise_plan<CTYPE>(*broadcast_op_plan_opt, add_lambda, ...);
});

Disclaimer: this is off the top of my head, and it may be possible to unify some of this stuff with dtype_util.h for further simplification, though dtype_util is mostly intended to cut the size/build time of portable ops.

Contributor Author

Good point. Let me see whether I run into other issues enabling such a refactor.

Contributor Author

OK, so I looked at the refactor required. I think it is doable, at the cost of moving the ET_SWITCH_REALB_TYPES macro to the call site in each respective op. The downside is that if you enable a new dtype for the optimized path, you then have to change all the call sites.

So I am not fully convinced that it is better to go down that route, but I want to hear your reasoning.

Contributor

> you have to change all the call sites.

That's just a matter of typing, right? If you plan to do it (I suppose optimizing Half/BFloat16 should be on our TODO list if the hardware supports the relevant instructions) and you really don't want to change 4-5 files later (you'll have to change them anyway for Half/BFloat16 specifically, because there are opt-outs), you could always #define ET_SWITCH_OPTIMIZED_ELEMENTWISE_BROADCAST_OP_TYPES ET_SWITCH_REALB_TYPES pre-emptively.
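
A rough sketch of the pre-emptive alias idea, using simplified stand-in macros rather than the real ET_SWITCH_* family (which take more arguments): call sites dispatch through one alias, so widening the supported dtypes later touches only a single #define.

#include <cstdio>

// Hypothetical stand-in for a dtype dispatch macro: runs BODY once per
// supported type with CTYPE bound to that type.
#define SWITCH_REALB_TYPES(BODY)   \
  { using CTYPE = float;  BODY; }  \
  { using CTYPE = double; BODY; }

// Pre-emptive alias used by the broadcast op call sites. Enabling more
// dtypes (say Half/BFloat16) later only touches this line.
#define SWITCH_OPTIMIZED_ELEMENTWISE_BROADCAST_OP_TYPES(BODY) \
  SWITCH_REALB_TYPES(BODY)

int main() {
  SWITCH_OPTIMIZED_ELEMENTWISE_BROADCAST_OP_TYPES(
      std::printf("element size: %zu\n", sizeof(CTYPE)));
  return 0;
}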

Contributor Author

OK, that's fair. But is your reasoning for this change simpler code, or do you see a perf impact?

I am not too attached to it, so I will just go ahead and do it, but I wanted to understand your reasoning.

Contributor

Simpler, less repetitive code.

Contributor Author

OK, I will make the change, but this will likely marginally increase size, since the whole handle_broadcast_elementwise function is now dtype-specialized.

@kimishpatel requested a review from swolchok February 11, 2025 00:07
Comment on lines +142 to +151
Tensor a = tf_a.make(
{2, 2, 3, 5},
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
Tensor b = tf_a.make(
{2, 1, 3, 5},
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});

Contributor

nit: it would probably be more reviewable to fill these programmatically, such as with std::iota, but certainly not blocking
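
For illustration, a small sketch of the programmatic fill the nit suggests, assuming the TensorFactory make overload that accepts a std::vector of data as used in these tests:

#include <numeric>
#include <vector>

int main() {
  // Generate 1..60 for the {2, 2, 3, 5} input and 1..30 for the
  // {2, 1, 3, 5} input instead of listing the values by hand.
  std::vector<float> a_data(2 * 2 * 3 * 5);
  std::vector<float> b_data(2 * 1 * 3 * 5);
  std::iota(a_data.begin(), a_data.end(), 1.0f);
  std::iota(b_data.begin(), b_data.end(), 1.0f);

  // In the test these would then be handed to the factory, e.g.
  // Tensor a = tf_a.make({2, 2, 3, 5}, a_data);
  // Tensor b = tf_a.make({2, 1, 3, 5}, b_data);
  return 0;
}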

Comment on lines +157 to +160
/*data=*/{2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45,
47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75,
62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90});

Contributor

ditto programmatic fill

@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@kimishpatel changed the base branch from gh/kimishpatel/154/base to main February 13, 2025 14:37
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@kimishpatel merged commit b71d873 into main Feb 15, 2025
45 of 48 checks passed
@kimishpatel deleted the gh/kimishpatel/154/head branch February 15, 2025 04:24