@@ -754,6 +754,152 @@ void update_cache(
}
}

+ /*
+   Input params
+   @param[in] q_projected Input projected with the query weights.
+   Format [n_layers, batch size, seq_len, num heads, head dim]
+   @param[in] k_projected Input projected with the key weights.
+   Format [n_layers, batch size, seq_len, num heads, head dim]
+   @param[in] v_projected Input projected with the value weights.
+   Format [n_layers, batch size, seq_len, num heads, head dim]
+   @param[in] key_cache Cache of previous k_projected.
+   Format [n_layers, batch size, max_seq_len, num heads, head dim]
+   @param[in] value_cache Cache of previous v_projected.
+   Format [n_layers, batch size, max_seq_len, num heads, head dim]
+   ...
+   @param[in] start_pos Sequence position at which the new tokens start.
+   @param[in] seq_len Sequence length, i.e., the seq_len dim of q_projected.
+ */
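+ // Note: the format strings above include n_layers, but the dim() == 4 check
+ // below implies each call sees per-layer tensors, i.e. q/k/v of shape
+ // [batch size, seq_len, num heads, head dim] and caches of shape
+ // [batch size, max_seq_len, num heads, head dim].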
+ Tensor& custom_sdpa_out(
+     RuntimeContext& ctx,
+     const Tensor& q,
+     const Tensor& k,
+     const Tensor& v,
+     const int64_t start_pos,
+     const int64_t seq_len,
+     const optional<Tensor>& attn_mask,
+     const double dropout_p,
+     const bool is_causal,
+     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+     const optional<double> scale,
+     Tensor& output) {
+   ET_KERNEL_CHECK_MSG(
+       ctx,
+       !attn_mask.has_value() || !is_causal,
+       InvalidArgument,
+       output,
+       "attn_mask and is_causal cannot be set at the same time");
+
+   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
+
+   auto q_seq_len = q.size(1);
+
+   // Refactor the following into a create_view util, perhaps using
+   // TensorPtr.
+   std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
+       0, 1, 2, 3};
+   std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
+   sliced_key_sizes[0] = k.size(0);
+   sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
+   sliced_key_sizes[2] = k.size(2);
+   sliced_key_sizes[3] = k.size(3);
+   std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
+   dim_order_to_stride_nocheck(
+       sliced_key_sizes.data(),
+       sliced_key_dim_order.data(),
+       util::kKVDim,
+       sliced_key_strides.data());
+   // Since the cache is sliced, the batch stride needs to stay the same.
+   sliced_key_strides[0] = k.strides()[0];
+   void* key_cache_data = k.mutable_data_ptr();
+   TensorImpl k_impl = TensorImpl(
+       k.scalar_type(),
+       util::kKVDim,
+       sliced_key_sizes.data(),
+       key_cache_data,
+       sliced_key_dim_order.data(),
+       sliced_key_strides.data(),
+       TensorShapeDynamism::STATIC);
+   Tensor sliced_key_cache(&k_impl);
+
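+   // sliced_key_cache is a non-owning view into the key cache: it reuses the
+   // cache's storage but exposes only the first (start_pos + seq_len)
+   // positions along the sequence dim. Strides are recomputed from the sliced
+   // sizes, except the batch stride, which keeps the full cache's value so
+   // each batch entry still starts at its original offset in the buffer.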
+   std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
+       0, 1, 2, 3};
+   std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
+   sliced_value_sizes[0] = v.size(0);
+   sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
+   sliced_value_sizes[2] = v.size(2);
+   sliced_value_sizes[3] = v.size(3);
+   std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
+   dim_order_to_stride_nocheck(
+       sliced_value_sizes.data(),
+       sliced_value_dim_order.data(),
+       util::kKVDim,
+       sliced_value_strides.data());
+   // Since the cache is sliced, the batch stride needs to stay the same.
+   sliced_value_strides[0] = v.strides()[0];
+   void* value_cache_data = v.mutable_data_ptr();
+   TensorImpl value_impl = TensorImpl(
+       v.scalar_type(),
+       util::kKVDim,
+       sliced_value_sizes.data(),
+       value_cache_data,
+       sliced_value_dim_order.data(),
+       sliced_value_strides.data(),
+       TensorShapeDynamism::STATIC);
+   Tensor sliced_value_cache(&value_impl);
+
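+   // Both sliced views alias the cache tensors, so the attention below reads
+   // just the valid prefix of the caches without copying any data.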
+   ET_KERNEL_CHECK(
+       ctx,
+       resize_tensor(output, q.sizes()) == Error::Ok,
+       InvalidArgument,
+       output);
+
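+   // The two integer template arguments below are tile sizes for the blocked
+   // attention (presumably the query and key/value split sizes used by
+   // cpu_flash_attention); longer query sequences get a larger query block.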
+   // TODO(task): replace the template param selection logic with whatever
+   // appropriately makes more sense for the target CPU.
+   ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+     // TODO: we need to re-evaluate this for ARM CPUs; there can be many, so
+     // instead of templatizing we might consider another approach.
+     if (q_seq_len >= 768) {
+       cpu_flash_attention<CTYPE, 256, 512>(
+           output,
+           q,
+           sliced_key_cache,
+           sliced_value_cache,
+           dropout_p,
+           is_causal,
+           attn_mask,
+           scale,
+           true,
+           start_pos);
+     } else if (q_seq_len >= 192) {
+       cpu_flash_attention<CTYPE, 64, 512>(
+           output,
+           q,
+           sliced_key_cache,
+           sliced_value_cache,
+           dropout_p,
+           is_causal,
+           attn_mask,
+           scale,
+           true,
+           start_pos);
+     } else {
+       cpu_flash_attention<CTYPE, 32, 512>(
+           output,
+           q,
+           sliced_key_cache,
+           sliced_value_cache,
+           dropout_p,
+           is_causal,
+           attn_mask,
+           scale,
+           true,
+           start_pos);
+     }
+   });
+   return output;
+ }
} // anonymous namespace

Tensor& flash_attention_kernel_out(
@@ -860,129 +1006,24 @@ Tensor& sdpa_with_kv_cache_out(
      InvalidArgument,
      output);

-   ET_KERNEL_CHECK_MSG(
-       ctx,
-       !attn_mask.has_value() || !is_causal,
-       InvalidArgument,
-       output,
-       "attn_mask and is_causal cannot be set at the same time");
-
  ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");

  update_cache(k_projected, key_cache, start_pos, seq_len);
  update_cache(v_projected, value_cache, start_pos, seq_len);

-   auto q_seq_len = q_projected.size(1);
-
-   std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
-       0, 1, 2, 3};
-   std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
-   sliced_key_sizes[0] = key_cache.size(0);
-   sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
-   sliced_key_sizes[2] = key_cache.size(2);
-   sliced_key_sizes[3] = key_cache.size(3);
-   std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
-   dim_order_to_stride_nocheck(
-       sliced_key_sizes.data(),
-       sliced_key_dim_order.data(),
-       util::kKVDim,
-       sliced_key_strides.data());
-   // since the cache is sliced, the batch stride needs to stay the same.
-   sliced_key_strides[0] = key_cache.strides()[0];
-   void* key_cache_data = key_cache.mutable_data_ptr();
-   TensorImpl k_impl = TensorImpl(
-       key_cache.scalar_type(),
-       util::kKVDim,
-       sliced_key_sizes.data(),
-       key_cache_data,
-       sliced_key_dim_order.data(),
-       sliced_key_strides.data(),
-       TensorShapeDynamism::STATIC);
-   Tensor sliced_key_cache(&k_impl);
-
-   std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
-       0, 1, 2, 3};
-   std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
-   sliced_value_sizes[0] = value_cache.size(0);
-   sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
-   sliced_value_sizes[2] = value_cache.size(2);
-   sliced_value_sizes[3] = value_cache.size(3);
-   std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
-   dim_order_to_stride_nocheck(
-       sliced_value_sizes.data(),
-       sliced_value_dim_order.data(),
-       util::kKVDim,
-       sliced_value_strides.data());
-   // since the cache is sliced, the batch stride needs to stay the same.
-   sliced_value_strides[0] = value_cache.strides()[0];
-   void* value_cache_data = value_cache.mutable_data_ptr();
-   TensorImpl value_impl = TensorImpl(
-       value_cache.scalar_type(),
-       util::kKVDim,
-       sliced_value_sizes.data(),
-       value_cache_data,
-       sliced_value_dim_order.data(),
-       sliced_value_strides.data(),
-       TensorShapeDynamism::STATIC);
-   Tensor sliced_value_cache(&value_impl);
-
-   // Is this true?
-   // Cant do this as is because the expectation of this kernel is
-   // that q, k, v are [B, num heads, seq length, head dim]
-   // and the cache is [B, max seq len, num heads, head dim]
-   // and q, k, v are all [B, seq length, num heads, head dim]
-
-   ET_KERNEL_CHECK(
+   custom_sdpa_out(
      ctx,
-       resize_tensor(output, q_projected.sizes()) == Error::Ok,
-       InvalidArgument,
+       q_projected,
+       key_cache,
+       value_cache,
+       start_pos,
+       seq_len,
+       attn_mask,
+       dropout_p,
+       is_causal,
+       scale,
      output);

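+   // With the attention math factored out into custom_sdpa_out above, this
+   // op now only updates the KV caches and delegates the rest.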
-   // TODO(task): replace the template param selection logic
-   // with whatever apprpriately makes more sense for
-   ET_SWITCH_FLOAT_TYPES(
-       q_projected.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-         // TODO we need to re-evaluate this for ARM CPUs
-         // And there can be many so instead of templatizing
-         // we might consider another appraoch
-         if (q_seq_len >= 768) {
-           cpu_flash_attention<CTYPE, 256, 512>(
-               output,
-               q_projected,
-               sliced_key_cache,
-               sliced_value_cache,
-               dropout_p,
-               is_causal,
-               attn_mask,
-               scale,
-               true,
-               start_pos);
-         } else if (q_seq_len >= 192) {
-           cpu_flash_attention<CTYPE, 64, 512>(
-               output,
-               q_projected,
-               sliced_key_cache,
-               sliced_value_cache,
-               dropout_p,
-               is_causal,
-               attn_mask,
-               scale,
-               true,
-               start_pos);
-         } else {
-           cpu_flash_attention<CTYPE, 32, 512>(
-               output,
-               q_projected,
-               sliced_key_cache,
-               sliced_value_cache,
-               dropout_p,
-               is_causal,
-               attn_mask,
-               scale,
-               true,
-               start_pos);
-         }
-       });
  return output;
}
} // namespace native