[Executorch][llama] Add custom_sdpa and use that instead of sdpa_with_kv_cache #5621

Closed · wants to merge 1 commit
41 changes: 22 additions & 19 deletions examples/models/llama2/source_transformation/sdpa.py
@@ -46,25 +46,28 @@ def forward(
# returns dequantized kv cache
# Not most optimal. Optimizations to follow next
k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
# Note that this path will still mutate the k_cache, v_cache in place.
# When we are not using a quantized kv cache, this just mutates
# the original kv cache.
# When we are using a quantized kv cache, this mutates the
# k_cache, v_cache returned from the cache update operation,
# which just dequantizes the cache and returns the result.
# Future diffs will optimize this.
output = torch.ops.llama.sdpa_with_kv_cache(
q,
k,
v,
k_cache,
v_cache,
input_pos[-1].item(),
seqlen,
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
output = torch.ops.llama.custom_sdpa(
q,
k_cache,
v_cache,
input_pos[0].item(),
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
else:
output = torch.ops.llama.sdpa_with_kv_cache(
q,
k,
v,
k_cache,
v_cache,
input_pos[0].item(),
seqlen,
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
return output.view(bsz, seqlen, self.dim)
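
For context on the call-site change above: a minimal sketch of invoking custom_sdpa from eager PyTorch on caches that have already been updated (and, on the quantized path, dequantized). Unlike sdpa_with_kv_cache, the new op does not write k/v into the cache itself. The import path, shapes, and head counts below are illustrative assumptions, not part of this diff.

# Sketch only. Assumptions: the custom ops are registered by importing the
# sdpa_with_kv_cache module, and the caches use a
# [bsz, max_seq_len, n_heads, head_dim] layout.
import torch

from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401

bsz, n_heads, head_dim, max_seq_len = 1, 8, 64, 128
q = torch.randn(bsz, 1, n_heads, head_dim)  # single decode-step query
k_cache = torch.randn(bsz, max_seq_len, n_heads, head_dim)  # already-updated caches
v_cache = torch.randn(bsz, max_seq_len, n_heads, head_dim)
start_pos = 16  # position of the current token in the cache

output = torch.ops.llama.custom_sdpa(
    q,
    k_cache,
    v_cache,
    start_pos,
    None,  # attn_mask
    0,     # dropout probability, ignored by the kernel
    True,  # is_causal
)
assert output.shape == q.shape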


141 changes: 73 additions & 68 deletions extension/llm/custom_ops/op_sdpa.cpp
@@ -754,6 +754,74 @@ void update_cache(
}
}

} // anonymous namespace

Tensor& flash_attention_kernel_out(
RuntimeContext& ctx,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output) {
(void)ctx;
ET_KERNEL_CHECK(
ctx,
validate_flash_attention_args(query, key, value, attn_mask),
InvalidArgument,
output);

ET_KERNEL_CHECK(
ctx,
resize_tensor(output, query.sizes()) == Error::Ok,
InvalidArgument,
output);

auto q_seq_len = query.size(2);

ET_SWITCH_FLOAT_TYPES(
query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
// TODO: we need to re-evaluate this for ARM CPUs.
// There can be many, so instead of templatizing
// we might consider another approach.
if (q_seq_len >= 768) {
cpu_flash_attention<CTYPE, 256, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
} else if (q_seq_len >= 192) {
cpu_flash_attention<CTYPE, 64, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
} else {
cpu_flash_attention<CTYPE, 32, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
}
});
return output;
}

/*
Input params
@param[in] q_projected Projected query with query weights.
@@ -900,74 +968,6 @@ Tensor& custom_sdpa_out(
});
return output;
}
} // anonymous namespace

Tensor& flash_attention_kernel_out(
KernelRuntimeContext& ctx,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output) {
(void)ctx;
ET_KERNEL_CHECK(
ctx,
validate_flash_attention_args(query, key, value, attn_mask),
InvalidArgument,
output);

ET_KERNEL_CHECK(
ctx,
resize_tensor(output, query.sizes()) == Error::Ok,
InvalidArgument,
output);

auto q_seq_len = query.size(2);

ET_SWITCH_FLOAT_TYPES(
query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
// TODO: we need to re-evaluate this for ARM CPUs.
// There can be many, so instead of templatizing
// we might consider another approach.
if (q_seq_len >= 768) {
cpu_flash_attention<CTYPE, 256, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
} else if (q_seq_len >= 192) {
cpu_flash_attention<CTYPE, 64, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
} else {
cpu_flash_attention<CTYPE, 32, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
}
});
return output;
}

/*
Input params
@param[in] q_projected Projected query with query weights.
@@ -1033,3 +1033,8 @@ EXECUTORCH_LIBRARY(
llama,
"sdpa_with_kv_cache.out",
torch::executor::native::sdpa_with_kv_cache_out);

EXECUTORCH_LIBRARY(
llama,
"custom_sdpa.out",
torch::executor::native::custom_sdpa_out);
13 changes: 13 additions & 0 deletions extension/llm/custom_ops/op_sdpa.h
@@ -31,6 +31,19 @@ Tensor& sdpa_with_kv_cache_out(
const optional<double> scale,
Tensor& output);

Tensor& custom_sdpa_out(
RuntimeContext& ctx,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const int64_t start_pos,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output);

Tensor& flash_attention_kernel_out(
KernelRuntimeContext& ctx,
const Tensor& query,
62 changes: 58 additions & 4 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -82,6 +82,51 @@ at::Tensor sdpa_with_kv_cache_aten(
return output;
}

Tensor& custom_sdpa_out_no_context(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const int64_t start_pos,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output) {
exec_aten::RuntimeContext context{};
return torch::executor::native::custom_sdpa_out(
context,
q,
k,
v,
start_pos,
attn_mask,
dropout_p,
is_causal,
scale,
output);
}

at::Tensor custom_sdpa_aten(
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
const int64_t start_pos,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const c10::optional<at::Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const c10::optional<double> scale) {
auto output = at::empty_like(q);
WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
(q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
return output;
}

Tensor& update_quantized_cache_out_no_context(
const Tensor& value,
Tensor& cache,
@@ -115,6 +160,14 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
"sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
m.def(
"custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
"float? scale=None) -> Tensor");
m.def(
"custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
"float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"update_quantized_cache(Tensor value, Tensor(a!) cache, "
"SymInt start_pos) -> Tensor");
@@ -123,17 +176,18 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
"SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
}

// TODO: Rename this file to op_custom_ops_aot.cpp
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl(
"sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
m.impl(
"sdpa_with_kv_cache.out",
WRAP_TO_ATEN(
torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
}

// TODO: Rename this file to op_custom_ops_aot.cpp
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten);
m.impl(
"custom_sdpa.out",
WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
m.impl(
"update_quantized_cache",
torch::executor::native::update_quantized_cache_aten);
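
A usage note on the two schemas registered above: custom_sdpa is the functional overload (it allocates its result via empty_like), while custom_sdpa.out writes into a caller-provided buffer, which is the form the runtime kernel ultimately uses. A small eager-mode sketch, under the same loading and shape assumptions as the earlier example:

import torch

from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401

q = torch.randn(1, 1, 8, 64)
k_cache = torch.randn(1, 128, 8, 64)
v_cache = torch.randn(1, 128, 8, 64)

# Functional overload: allocates and returns the output.
out_functional = torch.ops.llama.custom_sdpa(
    q, k_cache, v_cache, 0, None, 0.0, True, None
)

# Out-variant overload: fills a preallocated buffer (keyword-only `out`).
out_buffer = torch.empty_like(q)
torch.ops.llama.custom_sdpa.out(
    q, k_cache, v_cache, 0, None, 0.0, True, None, out=out_buffer
)
torch.testing.assert_close(out_functional, out_buffer)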
29 changes: 29 additions & 0 deletions extension/llm/custom_ops/sdpa_with_kv_cache.py
@@ -141,6 +141,35 @@ def fast_hadamard_transform_meta(mat):
return torch.empty_like(mat)


@impl(custom_ops_lib, "custom_sdpa", "Meta")
def custom_sdpa(
query,
key_cache,
value_cache,
start_pos,
attn_mask=None,
drpout_p=0.0,
is_causal=False,
scale=None,
):
seq_len = query.size(1)
_validate_params(
query,
key_cache,
value_cache,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
drpout_p,
is_causal,
scale,
)

return torch.empty_like(query)


def _validate_update_cache_params(
value,
cache,
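
The Meta registration above is what lets FakeTensor-based tracing (e.g. torch.export) propagate shapes through custom_sdpa without running the CPU kernel. A sketch of that behavior, under the same import-path assumption and with illustrative shapes:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401

with FakeTensorMode():
    q = torch.empty(1, 1, 8, 64)
    k_cache = torch.empty(1, 128, 8, 64)
    v_cache = torch.empty(1, 128, 8, 64)
    # Dispatches to the "Meta" impl above: shape/dtype propagation only.
    out = torch.ops.llama.custom_sdpa(q, k_cache, v_cache, 16, None, 0.0, True, None)

assert out.shape == q.shape  # the meta kernel returns torch.empty_like(query)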