Commit 73bf0b5
Refactor custom SDPA op to separate kv cache update from the custom sdpa op
Differential Revision: [D62301837](https://our.internmc.facebook.com/intern/diff/D62301837/) [ghstack-poisoned]
1 parent dd678eb commit 73bf0b5
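The refactor leaves sdpa_with_kv_cache_out as a thin wrapper: it writes the new K/V projections into the caches and then delegates the attention math to the new, cache-free custom_sdpa_out. A minimal sketch of the resulting call pattern, using only names and argument order taken from the diff below:

// Call pattern introduced by this commit (names from the diff below):
// 1) write the new K/V projections into the caches,
// 2) run attention over the first start_pos + seq_len cache positions.
update_cache(k_projected, key_cache, start_pos, seq_len);
update_cache(v_projected, value_cache, start_pos, seq_len);
custom_sdpa_out(
    ctx,
    q_projected, // projected query; must be a 4D tensor
    key_cache,
    value_cache,
    start_pos,
    seq_len,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    output);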

File tree

1 file changed: +156 −115 lines
extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 156 additions & 115 deletions
@@ -754,6 +754,152 @@ void update_cache(
   }
 }
 
+/*
+  Input params
+  @param[in] q_projected Projected query with query weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] k_projected Projected query with key weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] v_projected Projected query with value weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] key_cache Cache of previous k_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  @param[in] value_cache Cache of previous v_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  ....
+  @param[in] start_pos: sequence position
+  @param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
+*/
+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const int64_t seq_len,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      !attn_mask.has_value() || !is_causal,
+      InvalidArgument,
+      output,
+      "attn_mask and is_causal cannot be set at the same time");
+
+  ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
+
+  auto q_seq_len = q.size(1);
+
+  // Refactor the following into create_view util perhaps using
+  // TensorPtr
+  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
+      0, 1, 2, 3};
+  std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
+  sliced_key_sizes[0] = k.size(0);
+  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
+  sliced_key_sizes[2] = k.size(2);
+  sliced_key_sizes[3] = k.size(3);
+  std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
+  dim_order_to_stride_nocheck(
+      sliced_key_sizes.data(),
+      sliced_key_dim_order.data(),
+      util::kKVDim,
+      sliced_key_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_key_strides[0] = k.strides()[0];
+  void* key_cache_data = k.mutable_data_ptr();
+  TensorImpl k_impl = TensorImpl(
+      k.scalar_type(),
+      util::kKVDim,
+      sliced_key_sizes.data(),
+      key_cache_data,
+      sliced_key_dim_order.data(),
+      sliced_key_strides.data(),
+      TensorShapeDynamism::STATIC);
+  Tensor sliced_key_cache(&k_impl);
+
+  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
+      0, 1, 2, 3};
+  std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
+  sliced_value_sizes[0] = v.size(0);
+  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
+  sliced_value_sizes[2] = v.size(2);
+  sliced_value_sizes[3] = v.size(3);
+  std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
+  dim_order_to_stride_nocheck(
+      sliced_value_sizes.data(),
+      sliced_value_dim_order.data(),
+      util::kKVDim,
+      sliced_value_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_value_strides[0] = v.strides()[0];
+  void* value_cache_data = v.mutable_data_ptr();
+  TensorImpl value_impl = TensorImpl(
+      v.scalar_type(),
+      util::kKVDim,
+      sliced_value_sizes.data(),
+      value_cache_data,
+      sliced_value_dim_order.data(),
+      sliced_value_strides.data(),
+      TensorShapeDynamism::STATIC);
+  Tensor sliced_value_cache(&value_impl);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(output, q.sizes()) == Error::Ok,
+      InvalidArgument,
+      output);
+
+  // TODO(task): replace the template param selection logic
+  // with whatever appropriately makes more sense.
+  ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+    // TODO we need to re-evaluate this for ARM CPUs
+    // And there can be many so instead of templatizing
+    // we might consider another approach
+    if (q_seq_len >= 768) {
+      cpu_flash_attention<CTYPE, 256, 512>(
+          output,
+          q,
+          sliced_key_cache,
+          sliced_value_cache,
+          dropout_p,
+          is_causal,
+          attn_mask,
+          scale,
+          true,
+          start_pos);
+    } else if (q_seq_len >= 192) {
+      cpu_flash_attention<CTYPE, 64, 512>(
+          output,
+          q,
+          sliced_key_cache,
+          sliced_value_cache,
+          dropout_p,
+          is_causal,
+          attn_mask,
+          scale,
+          true,
+          start_pos);
+    } else {
+      cpu_flash_attention<CTYPE, 32, 512>(
+          output,
+          q,
+          sliced_key_cache,
+          sliced_value_cache,
+          dropout_p,
+          is_causal,
+          attn_mask,
+          scale,
+          true,
+          start_pos);
+    }
+  });
+  return output;
+}
 } // anonymous namespace
 
 Tensor& flash_attention_kernel_out(
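The only non-obvious step in the view construction above is resetting the batch stride after dim_order_to_stride_nocheck: a slice over the sequence dimension must keep the full cache's batch pitch, not the pitch a contiguous tensor of the sliced shape would have. A small self-contained illustration of that point, not part of op_sdpa.cpp (the helper name and the plain int64_t strides are assumptions of the sketch):

#include <array>
#include <cstdint>

// Strides (in elements) for a view over the first `valid_len` positions of a
// contiguous [batch, max_seq_len, num_heads, head_dim] cache. The inner
// strides are the same as for a contiguous tensor of the sliced shape, but
// the batch stride must stay at max_seq_len * num_heads * head_dim, which is
// what the sliced_key_strides[0] / sliced_value_strides[0] overrides above
// preserve.
std::array<int64_t, 4> sliced_cache_strides(
    int64_t max_seq_len,
    int64_t valid_len, // start_pos + seq_len; does not affect the strides
    int64_t num_heads,
    int64_t head_dim) {
  (void)valid_len;
  return {
      max_seq_len * num_heads * head_dim, // batch pitch of the full cache
      num_heads * head_dim,
      head_dim,
      1};
}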
@@ -860,129 +1006,24 @@ Tensor& sdpa_with_kv_cache_out(
       InvalidArgument,
       output);
 
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !attn_mask.has_value() || !is_causal,
-      InvalidArgument,
-      output,
-      "attn_mask and is_causal cannot be set at the same time");
-
   ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");
 
   update_cache(k_projected, key_cache, start_pos, seq_len);
   update_cache(v_projected, value_cache, start_pos, seq_len);
 
-  auto q_seq_len = q_projected.size(1);
-
-  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
-      0, 1, 2, 3};
-  std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
-  sliced_key_sizes[0] = key_cache.size(0);
-  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
-  sliced_key_sizes[2] = key_cache.size(2);
-  sliced_key_sizes[3] = key_cache.size(3);
-  std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
-  dim_order_to_stride_nocheck(
-      sliced_key_sizes.data(),
-      sliced_key_dim_order.data(),
-      util::kKVDim,
-      sliced_key_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_key_strides[0] = key_cache.strides()[0];
-  void* key_cache_data = key_cache.mutable_data_ptr();
-  TensorImpl k_impl = TensorImpl(
-      key_cache.scalar_type(),
-      util::kKVDim,
-      sliced_key_sizes.data(),
-      key_cache_data,
-      sliced_key_dim_order.data(),
-      sliced_key_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_key_cache(&k_impl);
-
-  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
-      0, 1, 2, 3};
-  std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
-  sliced_value_sizes[0] = value_cache.size(0);
-  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
-  sliced_value_sizes[2] = value_cache.size(2);
-  sliced_value_sizes[3] = value_cache.size(3);
-  std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
-  dim_order_to_stride_nocheck(
-      sliced_value_sizes.data(),
-      sliced_value_dim_order.data(),
-      util::kKVDim,
-      sliced_value_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_value_strides[0] = value_cache.strides()[0];
-  void* value_cache_data = value_cache.mutable_data_ptr();
-  TensorImpl value_impl = TensorImpl(
-      value_cache.scalar_type(),
-      util::kKVDim,
-      sliced_value_sizes.data(),
-      value_cache_data,
-      sliced_value_dim_order.data(),
-      sliced_value_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_value_cache(&value_impl);
-
-  // Is this true?
-  // Cant do this as is because the expectation of this kernel is
-  // that q, k, v are [B, num heads, seq length, head dim]
-  // and the cache is [B, max seq len, num heads, head dim]
-  // and q, k, v are all [B, seq length, num heads, head dim]
-
-  ET_KERNEL_CHECK(
+  custom_sdpa_out(
       ctx,
-      resize_tensor(output, q_projected.sizes()) == Error::Ok,
-      InvalidArgument,
+      q_projected,
+      key_cache,
+      value_cache,
+      start_pos,
+      seq_len,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
       output);
 
-  // TODO(task): replace the template param selection logic
-  // with whatever apprpriately makes more sense for
-  ET_SWITCH_FLOAT_TYPES(
-      q_projected.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-        // TODO we need to re-evaluate this for ARM CPUs
-        // And there can be many so instead of templatizing
-        // we might consider another appraoch
-        if (q_seq_len >= 768) {
-          cpu_flash_attention<CTYPE, 256, 512>(
-              output,
-              q_projected,
-              sliced_key_cache,
-              sliced_value_cache,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              true,
-              start_pos);
-        } else if (q_seq_len >= 192) {
-          cpu_flash_attention<CTYPE, 64, 512>(
-              output,
-              q_projected,
-              sliced_key_cache,
-              sliced_value_cache,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              true,
-              start_pos);
-        } else {
-          cpu_flash_attention<CTYPE, 32, 512>(
-              output,
-              q_projected,
-              sliced_key_cache,
-              sliced_value_cache,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              true,
-              start_pos);
-        }
-      });
   return output;
 }
 } // namespace native
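For reference, the q_seq_len thresholds kept by the moved code (768 and 192) pick the block sizes passed to cpu_flash_attention as template parameters. A standalone sketch of that dispatch shape, with a placeholder kernel standing in for the real templated function; the source itself marks the threshold choice as a TODO to be re-evaluated (e.g. for ARM CPUs):

#include <cstdint>
#include <cstdio>

// Placeholder for cpu_flash_attention<CTYPE, qBlock, kvBlock>; only the
// template-parameter dispatch shape is of interest here.
template <int64_t kQBlock, int64_t kKVBlock>
void attention_placeholder(int64_t q_seq_len) {
  std::printf(
      "q_seq_len=%lld -> q block %lld, kv block %lld\n",
      static_cast<long long>(q_seq_len),
      static_cast<long long>(kQBlock),
      static_cast<long long>(kKVBlock));
}

// Mirrors the thresholds used in custom_sdpa_out above.
void dispatch_flash_attention(int64_t q_seq_len) {
  if (q_seq_len >= 768) {
    attention_placeholder<256, 512>(q_seq_len);
  } else if (q_seq_len >= 192) {
    attention_placeholder<64, 512>(q_seq_len);
  } else {
    attention_placeholder<32, 512>(q_seq_len);
  }
}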
