diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py
index 263a98a66b3..54184f89dbc 100644
--- a/examples/models/llama2/source_transformation/sdpa.py
+++ b/examples/models/llama2/source_transformation/sdpa.py
@@ -46,25 +46,28 @@ def forward(
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
             k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
-            # Note that this path will still inplace mutate the k_cache, v_cache.
-            # WHen we are not using quantized kv cache, this will just mutate
-            # the original kv cache.
-            # When we aer using quantized kv cache, this will mutate
-            # k_cache, v_cache that is returned from cache update operation.
-            # This operation just dequantized thee cache and returns that.
-            # Future diffs will optimize this
-        output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
-            k_cache,
-            v_cache,
-            input_pos[-1].item(),
-            seqlen,
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k_cache,
+                v_cache,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.sdpa_with_kv_cache(
+                q,
+                k,
+                v,
+                k_cache,
+                v_cache,
+                input_pos[0].item(),
+                seqlen,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim)
diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index b5cb2b55e99..4316a68afa4 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -754,6 +754,74 @@ void update_cache(
   }
 }
 
+} // anonymous namespace
+
+Tensor& flash_attention_kernel_out(
+    RuntimeContext& ctx,
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  (void)ctx;
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_flash_attention_args(query, key, value, attn_mask),
+      InvalidArgument,
+      output);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(output, query.sizes()) == Error::Ok,
+      InvalidArgument,
+      output);
+
+  auto q_seq_len = query.size(2);
+
+  ET_SWITCH_FLOAT_TYPES(
+      query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+        // TODO we need to re-evaluate this for ARM CPUs
+        // And there can be many so instead of templatizing
+        // we might consider another appraoch
+        if (q_seq_len >= 768) {
+          cpu_flash_attention(
+              output,
+              query,
+              key,
+              value,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale);
+        } else if (q_seq_len >= 192) {
+          cpu_flash_attention(
+              output,
+              query,
+              key,
+              value,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale);
+        } else {
+          cpu_flash_attention(
+              output,
+              query,
+              key,
+              value,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale);
+        }
+      });
+  return output;
+}
+
 /*
   Input params
   @param[in] q_projected Projected query with query weights.
@@ -900,74 +968,6 @@ Tensor& custom_sdpa_out(
       });
   return output;
 }
-} // anonymous namespace
-
-Tensor& flash_attention_kernel_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& query,
-    const Tensor& key,
-    const Tensor& value,
-    const optional<Tensor>& attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
-    const optional<double> scale,
-    Tensor& output) {
-  (void)ctx;
-  ET_KERNEL_CHECK(
-      ctx,
-      validate_flash_attention_args(query, key, value, attn_mask),
-      InvalidArgument,
-      output);
-
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_tensor(output, query.sizes()) == Error::Ok,
-      InvalidArgument,
-      output);
-
-  auto q_seq_len = query.size(2);
-
-  ET_SWITCH_FLOAT_TYPES(
-      query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-        // TODO we need to re-evaluate this for ARM CPUs
-        // And there can be many so instead of templatizing
-        // we might consider another appraoch
-        if (q_seq_len >= 768) {
-          cpu_flash_attention(
-              output,
-              query,
-              key,
-              value,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale);
-        } else if (q_seq_len >= 192) {
-          cpu_flash_attention(
-              output,
-              query,
-              key,
-              value,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale);
-        } else {
-          cpu_flash_attention(
-              output,
-              query,
-              key,
-              value,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale);
-        }
-      });
-  return output;
-}
-
 /*
   Input params
   @param[in] q_projected Projected query with query weights.
@@ -1033,3 +1033,8 @@ EXECUTORCH_LIBRARY(
     llama,
     "sdpa_with_kv_cache.out",
     torch::executor::native::sdpa_with_kv_cache_out);
+
+EXECUTORCH_LIBRARY(
+    llama,
+    "custom_sdpa.out",
+    torch::executor::native::custom_sdpa_out);
diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h
index ce969b013d2..bc2202b9bd8 100644
--- a/extension/llm/custom_ops/op_sdpa.h
+++ b/extension/llm/custom_ops/op_sdpa.h
@@ -31,6 +31,19 @@ Tensor& sdpa_with_kv_cache_out(
     const optional<double> scale,
     Tensor& output);
 
+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output);
+
 Tensor& flash_attention_kernel_out(
     KernelRuntimeContext& ctx,
     const Tensor& query,
diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp
index f3674088fd7..c182903aa54 100644
--- a/extension/llm/custom_ops/op_sdpa_aot.cpp
+++ b/extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -82,6 +82,51 @@ at::Tensor sdpa_with_kv_cache_aten(
   return output;
 }
 
+Tensor& custom_sdpa_out_no_context(
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::native::custom_sdpa_out(
+      context,
+      q,
+      k,
+      v,
+      start_pos,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
+      output);
+}
+
+at::Tensor custom_sdpa_aten(
+    const at::Tensor& q,
+    const at::Tensor& k,
+    const at::Tensor& v,
+    const int64_t start_pos,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const c10::optional<at::Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const c10::optional<double> scale) {
+  auto output = at::empty_like(q);
+  WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
+  (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
+  return output;
+}
+
 Tensor& update_quantized_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
@@ -115,6 +160,14 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+  m.def(
+      "custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
+      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
+      "float? scale=None) -> Tensor");
+  m.def(
+      "custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
+      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
+      "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "update_quantized_cache(Tensor value, Tensor(a!) cache, "
       "SymInt start_pos) -> Tensor");
@@ -123,6 +176,7 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
 }
 
+// TODO: Rename this file to op_custom_ops_aot.cpp
 TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
@@ -130,10 +184,10 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
       "sdpa_with_kv_cache.out",
       WRAP_TO_ATEN(
          torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
-}
-
-// TODO: Rename this file to op_custom_ops_aot.cpp
-TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten);
+  m.impl(
+      "custom_sdpa.out",
+      WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
   m.impl(
       "update_quantized_cache",
       torch::executor::native::update_quantized_cache_aten);
diff --git a/extension/llm/custom_ops/sdpa_with_kv_cache.py b/extension/llm/custom_ops/sdpa_with_kv_cache.py
index d6c7fbab6f4..85021266b59 100644
--- a/extension/llm/custom_ops/sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/sdpa_with_kv_cache.py
@@ -141,6 +141,35 @@ def fast_hadamard_transform_meta(mat):
     return torch.empty_like(mat)
 
 
+@impl(custom_ops_lib, "custom_sdpa", "Meta")
+def custom_sdpa(
+    query,
+    key_cache,
+    value_cache,
+    start_pos,
+    attn_mask=None,
+    drpout_p=0.0,
+    is_causal=False,
+    scale=None,
+):
+    seq_len = query.size(1)
+    _validate_params(
+        query,
+        key_cache,
+        value_cache,
+        key_cache,
+        value_cache,
+        start_pos,
+        seq_len,
+        attn_mask,
+        drpout_p,
+        is_causal,
+        scale,
+    )
+
+    return torch.empty_like(query)
+
+
 def _validate_update_cache_params(
     value,
     cache,
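
Usage sketch (not part of the patch): based on the "custom_sdpa" schema registered in op_sdpa_aot.cpp above and the call site in sdpa.py, the new op should be invocable roughly as follows in eager mode. The import path, tensor shapes, cache layout, and values below are illustrative assumptions, not taken from this change.

import torch

# Importing this module is assumed to register torch.ops.llama.custom_sdpa
# (it loads the AOT custom-op library defined in this diff).
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401

bsz, seq_len, n_heads, head_dim, max_seq_len = 1, 4, 8, 64, 128
start_pos = 0

# Query for the current chunk; layout assumed (bsz, seq_len, n_heads, head_dim),
# matching the meta kernel's seq_len = query.size(1).
q = torch.randn(bsz, seq_len, n_heads, head_dim, dtype=torch.float32)

# Caches assumed already updated for positions [start_pos, start_pos + seq_len);
# layout assumed (bsz, max_seq_len, n_heads, head_dim).
k_cache = torch.randn(bsz, max_seq_len, n_heads, head_dim, dtype=torch.float32)
v_cache = torch.randn(bsz, max_seq_len, n_heads, head_dim, dtype=torch.float32)

# Positional args mirror the call added in sdpa.py:
# custom_sdpa(query, key, value, start_pos, attn_mask, drpout_p, is_causal)
out = torch.ops.llama.custom_sdpa(
    q,
    k_cache,
    v_cache,
    start_pos,
    None,  # Attention mask
    0,     # dropout probability. Ignored by the code
    True,  # is_causal
)
print(out.shape)  # same shape as q, per the meta kernel above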