
Commit 44ed2b2

[Executorch][llama] Add custom_sdpa and use that instead of sdpa_with_kv_cache
sdpa_with_kv_cache updates the kv cache. With a quantized kv cache, the cache update happens separately and the quantized cache is then dequantized. After that we call sdpa_with_kv_cache, which copies the k and v data into the dequantized cache even though this is not needed, because the actual cache is the quantized one. For very large context lengths this adds a significant amount of data copying. Subsequent diffs will deprecate the sdpa_with_kv_cache op and decompose it into a) an update_cache op and b) a custom_sdpa op.

Differential Revision: [D62623241](https://our.internmc.facebook.com/intern/diff/D62623241/)
ghstack-source-id: 244549445
Pull Request resolved: #5621
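
In rough pseudocode, the change to the quantized-kv-cache decode path looks like this; the two ops are the ones touched by this commit, while the surrounding names (kv_cache, start_pos, seq_len) are illustrative:

# Before: the cache update already dequantizes, yet sdpa_with_kv_cache
# copies k/v into the dequantized k_cache/v_cache a second time.
k_cache, v_cache = kv_cache.update(input_pos, k, v)
out = torch.ops.llama.sdpa_with_kv_cache(
    q, k, v, k_cache, v_cache, start_pos, seq_len, None, 0, True
)

# After: custom_sdpa only reads the already-updated caches, so the
# redundant copy (expensive at long context lengths) disappears.
k_cache, v_cache = kv_cache.update(input_pos, k, v)
out = torch.ops.llama.custom_sdpa(q, k_cache, v_cache, start_pos, None, 0, True)
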
1 parent a69818c commit 44ed2b2

File tree

examples/models/llama2/source_transformation/sdpa.py
extension/llm/custom_ops/op_sdpa.cpp
extension/llm/custom_ops/op_sdpa.h
extension/llm/custom_ops/op_sdpa_aot.cpp
extension/llm/custom_ops/sdpa_with_kv_cache.py

5 files changed: +195 −91 lines

examples/models/llama2/source_transformation/sdpa.py

Lines changed: 22 additions & 19 deletions
@@ -46,25 +46,28 @@ def forward(
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
             k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
-        # Note that this path will still inplace mutate the k_cache, v_cache.
-        # WHen we are not using quantized kv cache, this will just mutate
-        # the original kv cache.
-        # When we aer using quantized kv cache, this will mutate
-        # k_cache, v_cache that is returned from cache update operation.
-        # This operation just dequantized thee cache and returns that.
-        # Future diffs will optimize this
-        output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
-            k_cache,
-            v_cache,
-            input_pos[-1].item(),
-            seqlen,
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k_cache,
+                v_cache,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.sdpa_with_kv_cache(
+                q,
+                k,
+                v,
+                k_cache,
+                v_cache,
+                input_pos[0].item(),
+                seqlen,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim)
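
For intuition only, the custom_sdpa call above can be approximated in eager PyTorch as the sketch below. It assumes the [batch, seq_len, num_heads, head_dim] layout these ops take, equal numbers of query and kv heads, and is_causal=True with start_pos offsetting the causal mask into the cache; it is a reference sketch, not the kernel.

import torch

def custom_sdpa_reference(q, k_cache, v_cache, start_pos, scale=None):
    # q: [bsz, seqlen, n_heads, head_dim], new tokens at positions
    #    start_pos .. start_pos + seqlen - 1
    # k_cache / v_cache: [bsz, max_seq_len, n_heads, head_dim], already updated
    bsz, seqlen, n_heads, head_dim = q.shape
    end = start_pos + seqlen
    qt = q.transpose(1, 2)                        # [bsz, n_heads, seqlen, head_dim]
    kt = k_cache[:, :end].transpose(1, 2)
    vt = v_cache[:, :end].transpose(1, 2)
    scale = head_dim ** -0.5 if scale is None else scale
    scores = (qt @ kt.transpose(-1, -2)) * scale  # [bsz, n_heads, seqlen, end]
    # causal: query i may only attend to cache positions <= start_pos + i
    allowed = torch.arange(end) <= (start_pos + torch.arange(seqlen)).unsqueeze(1)
    scores = scores.masked_fill(~allowed, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ vt      # [bsz, n_heads, seqlen, head_dim]
    return out.transpose(1, 2).contiguous()       # same shape as q

The real op likewise returns a tensor shaped like q, which the caller then flattens with output.view(bsz, seqlen, self.dim) as in the diff above.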

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 73 additions & 68 deletions
@@ -754,6 +754,74 @@ void update_cache(
   }
 }

+} // anonymous namespace
+
+Tensor& flash_attention_kernel_out(
+    RuntimeContext& ctx,
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  (void)ctx;
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_flash_attention_args(query, key, value, attn_mask),
+      InvalidArgument,
+      output);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(output, query.sizes()) == Error::Ok,
+      InvalidArgument,
+      output);
+
+  auto q_seq_len = query.size(2);
+
+  ET_SWITCH_FLOAT_TYPES(
+      query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+        // TODO we need to re-evaluate this for ARM CPUs
+        // And there can be many so instead of templatizing
+        // we might consider another appraoch
+        if (q_seq_len >= 768) {
+          cpu_flash_attention<CTYPE, 256, 512>(
+              output,
+              query,
+              key,
+              value,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale);
+        } else if (q_seq_len >= 192) {
+          cpu_flash_attention<CTYPE, 64, 512>(
+              output,
+              query,
+              key,
+              value,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale);
+        } else {
+          cpu_flash_attention<CTYPE, 32, 512>(
+              output,
+              query,
+              key,
+              value,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale);
+        }
+      });
+  return output;
+}
+
 /*
 Input params
 @param[in] q_projected Projected query with query weights.
@@ -900,74 +968,6 @@ Tensor& custom_sdpa_out(
       });
   return output;
 }
-} // anonymous namespace
-
-Tensor& flash_attention_kernel_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& query,
-    const Tensor& key,
-    const Tensor& value,
-    const optional<Tensor>& attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
-    const optional<double> scale,
-    Tensor& output) {
-  (void)ctx;
-  ET_KERNEL_CHECK(
-      ctx,
-      validate_flash_attention_args(query, key, value, attn_mask),
-      InvalidArgument,
-      output);
-
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_tensor(output, query.sizes()) == Error::Ok,
-      InvalidArgument,
-      output);
-
-  auto q_seq_len = query.size(2);
-
-  ET_SWITCH_FLOAT_TYPES(
-      query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-        // TODO we need to re-evaluate this for ARM CPUs
-        // And there can be many so instead of templatizing
-        // we might consider another appraoch
-        if (q_seq_len >= 768) {
-          cpu_flash_attention<CTYPE, 256, 512>(
-              output,
-              query,
-              key,
-              value,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale);
-        } else if (q_seq_len >= 192) {
-          cpu_flash_attention<CTYPE, 64, 512>(
-              output,
-              query,
-              key,
-              value,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale);
-        } else {
-          cpu_flash_attention<CTYPE, 32, 512>(
-              output,
-              query,
-              key,
-              value,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale);
-        }
-      });
-  return output;
-}
-
 /*
 Input params
 @param[in] q_projected Projected query with query weights.
@@ -1033,3 +1033,8 @@ EXECUTORCH_LIBRARY(
     llama,
     "sdpa_with_kv_cache.out",
     torch::executor::native::sdpa_with_kv_cache_out);
+
+EXECUTORCH_LIBRARY(
+    llama,
+    "custom_sdpa.out",
+    torch::executor::native::custom_sdpa_out);

extension/llm/custom_ops/op_sdpa.h

Lines changed: 13 additions & 0 deletions
@@ -31,6 +31,19 @@ Tensor& sdpa_with_kv_cache_out(
     const optional<double> scale,
     Tensor& output);

+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output);
+
 Tensor& flash_attention_kernel_out(
     KernelRuntimeContext& ctx,
     const Tensor& query,

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 58 additions & 4 deletions
@@ -82,6 +82,51 @@ at::Tensor sdpa_with_kv_cache_aten(
   return output;
 }

+Tensor& custom_sdpa_out_no_context(
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::native::custom_sdpa_out(
+      context,
+      q,
+      k,
+      v,
+      start_pos,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
+      output);
+}
+
+at::Tensor custom_sdpa_aten(
+    const at::Tensor& q,
+    const at::Tensor& k,
+    const at::Tensor& v,
+    const int64_t start_pos,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const c10::optional<at::Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const c10::optional<double> scale) {
+  auto output = at::empty_like(q);
+  WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
+  (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
+  return output;
+}
+
 Tensor& update_quantized_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
@@ -115,6 +160,14 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
       "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
       "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+  m.def(
+      "custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
+      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
+      "float? scale=None) -> Tensor");
+  m.def(
+      "custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
+      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
+      "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "update_quantized_cache(Tensor value, Tensor(a!) cache, "
       "SymInt start_pos) -> Tensor");
@@ -123,17 +176,18 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
 }

+// TODO: Rename this file to op_custom_ops_aot.cpp
 TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
   m.impl(
       "sdpa_with_kv_cache.out",
       WRAP_TO_ATEN(
           torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
-}
-
-// TODO: Rename this file to op_custom_ops_aot.cpp
-TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten);
+  m.impl(
+      "custom_sdpa.out",
+      WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
   m.impl(
       "update_quantized_cache",
       torch::executor::native::update_quantized_cache_aten);
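
With the schema and the CompositeExplicitAutograd impls above, the functional custom_sdpa variant can be exercised eagerly from Python for quick testing. A minimal sketch, assuming the custom-op library has been loaded (the import path below is illustrative) and the [batch, seq_len, num_heads, head_dim] layout:

import torch

# Illustrative: importing the registration module (see sdpa_with_kv_cache.py
# below) is assumed to load the llama custom ops so the op name resolves.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401

bsz, n_heads, head_dim, max_ctx = 1, 8, 64, 128
q = torch.rand(bsz, 1, n_heads, head_dim)               # one decode step
k_cache = torch.rand(bsz, max_ctx, n_heads, head_dim)   # already-updated caches
v_cache = torch.rand(bsz, max_ctx, n_heads, head_dim)

# Schema: custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos,
#         Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False,
#         float? scale=None) -> Tensor
out = torch.ops.llama.custom_sdpa(q, k_cache, v_cache, 16, None, 0.0, True)
assert out.shape == q.shape  # custom_sdpa_aten allocates the output via empty_like(q)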

extension/llm/custom_ops/sdpa_with_kv_cache.py

Lines changed: 29 additions & 0 deletions
@@ -141,6 +141,35 @@ def fast_hadamard_transform_meta(mat):
     return torch.empty_like(mat)


+@impl(custom_ops_lib, "custom_sdpa", "Meta")
+def custom_sdpa(
+    query,
+    key_cache,
+    value_cache,
+    start_pos,
+    attn_mask=None,
+    drpout_p=0.0,
+    is_causal=False,
+    scale=None,
+):
+    seq_len = query.size(1)
+    _validate_params(
+        query,
+        key_cache,
+        value_cache,
+        key_cache,
+        value_cache,
+        start_pos,
+        seq_len,
+        attn_mask,
+        drpout_p,
+        is_causal,
+        scale,
+    )
+
+    return torch.empty_like(query)
+
+
 def _validate_update_cache_params(
     value,
     cache,
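
The Meta registration above only propagates shape and dtype (torch.empty_like(query)), which is what export-time tracing needs. A hedged sketch of that propagation on meta tensors, assuming the ops are registered via the illustrative import and that _validate_params accepts cache-shaped key/value arguments, as the registration above implies:

import torch

# Illustrative import; assumed to register the llama custom ops and Meta impls.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401

q = torch.empty(1, 1, 8, 64, device="meta")
k_cache = torch.empty(1, 128, 8, 64, device="meta")
v_cache = torch.empty(1, 128, 8, 64, device="meta")

# Dispatches to the "Meta" impl above: no kernel runs; only the output's
# shape and dtype are produced, mirroring torch.empty_like(query).
out = torch.ops.llama.custom_sdpa(q, k_cache, v_cache, 0, None, 0.0, True)
assert out.shape == q.shape and out.device.type == "meta"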
