@@ -226,7 +226,7 @@ Status PagedAttention<T>::CheckInputs(
226226 const Tensor* query = context->Input <Tensor>(0 );
227227 const Tensor* key_cache = context->Input <Tensor>(3 );
228228 const Tensor* value_cache = context->Input <Tensor>(4 );
229- const Tensor* positions = context->Input <Tensor>(6 );
229+ // const Tensor* positions = context->Input<Tensor>(6);
230230
231231 const auto & query_shape = query->Shape ();
232232 if (query_shape.NumDimensions () < 2 || query_shape.NumDimensions () > 3 ) {
@@ -255,10 +255,10 @@ Status PagedAttention<T>::CheckInputs(
255255 key_cache->Shape (), " " , value_cache->Shape ());
256256 }
257257
258- if (positions && positions->Shape ().Size () > 0 &&
259- positions->Shape ()[positions->Shape ().NumDimensions () - 1 ] != batch_size * seq_len) {
260- return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Invalid positions shape: " , positions->Shape ());
261- }
258+ // if (positions && positions->Shape().Size() > 0 &&
259+ // positions->Shape()[positions->Shape().NumDimensions() - 1] != batch_size * seq_len) {
260+ // return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid positions shape: ", positions->Shape());
261+ // }
262262
263263 int64_t num_prompt_tokens = input_metadata->num_prompt_tokens ;
264264
@@ -507,8 +507,8 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
507507 const Tensor* cos_sin_cache = context->Input <Tensor>(7 );
508508 const Tensor* kv_quant_param = (context->InputCount () > 8 ) ? context->Input <Tensor>(8 ) : nullptr ;
509509 ORT_UNUSED_PARAMETER (kv_quant_param);
510-
511- int seq_len = query-> Shape (). NumDimensions () == 3 ? query-> Shape () [1 ] : query-> Shape () [0 ];
510+ const auto & query_shape = query-> Shape ();
511+ int seq_len = query_shape [1 ] * query_shape [0 ];
512512 auto meta_data_space = this ->template GetScratchBuffer <int8_t >(std::max (1024 , seq_len * 3 ) * sizeof (int32_t ), context->GetComputeStream ());
513513 InputMetadata self_build_input_metadata;
514514 InputMetadata* input_metadata = GetOrCreateMedataFromInput (context, &self_build_input_metadata, meta_data_space.get ());
@@ -524,9 +524,9 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
524524 T* value_data = const_cast <T*>(value->Data <T>());
525525
526526 int64_t num_valid_tokens = input_metadata->num_valid_tokens ;
527- TensorShape output_shape = query-> Shape () ;
527+ TensorShape output_shape = query_shape ;
528528 if (gemm_buffer.get () == nullptr ) {
529- ORT_ENFORCE (query-> Shape ()[ 1 ] == num_heads_ * head_size_, " invlaid query shape" );
529+ ORT_ENFORCE (query_shape[output_shape. NumDimensions () - 1 ] == num_heads_ * head_size_, " invlaid query shape" );
530530 } else {
531531 output_shape[output_shape.NumDimensions () - 1 ] = num_heads_ * head_size_;
532532 TensorShapeVector new_shape (2 );
@@ -548,7 +548,8 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
548548 int64_t rot_dim = cos_sin_cache->Shape ()[1 ];
549549 ORT_ENFORCE (rot_dim == head_size_, " RoPE mask requires position input with shape [seq_len, head_size]" );
550550 rotary_embedding_neox (Stream (context), positions->Data <int64_t >(), static_cast <void *>(query_data),
551- static_cast <void *>(key_data), head_size_, cos_sin_cache->DataRaw (), num_valid_tokens,
551+ static_cast <void *>(key_data), head_size_, cos_sin_cache->DataRaw (),
552+ query_shape.Size ()/head_size_/num_heads_,
552553 rot_dim, num_heads_, num_kv_heads_, 1 );
553554 CHECK_CUDA_ERROR ();
554555 }
@@ -584,22 +585,50 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
584585 }
585586
586587 if (input_metadata->num_generation_tokens > 0 ) {
588+ constexpr int PARTITION_SIZE = 512 ;
589+ int max_num_partitions = ((input_metadata->max_context_len + PARTITION_SIZE - 1 ) / PARTITION_SIZE);
590+ // TODO: Tune this heuristic for choosing between paged_attention_v1 and paged_attention_v2.
591+ bool use_v1 = max_num_partitions == 1 || (query_shape[0 ] * query_shape[1 ]) > PARTITION_SIZE;
587592 int64_t generation_qeury_shape[3 ] = {num_valid_tokens - num_prompt_tokens, num_heads_, head_size_};
588- single_query_cached_kv_attention (Stream (context),
589- output->MutableData <MLFloat16>() + num_prompt_tokens * num_heads_ * head_size_,
590- query_data + num_prompt_tokens * num_heads_ * head_size_,
591- key_cache->Data <MLFloat16>(),
592- value_cache->Data <MLFloat16>(),
593- head_mapping_.get (),
594- scale_,
595- reinterpret_cast <const int32_t *>(input_metadata->block_tables ),
596- input_metadata->max_num_blocks_per_seq ,
597- reinterpret_cast <const int32_t *>(input_metadata->context_lens ),
598- value_cache->Shape ()[3 ],
599- input_metadata->max_context_len ,
600- nullptr ,
601- generation_qeury_shape,
602- num_queries_per_kv_, 1 );
593+ if (use_v1){
594+ paged_attention_v1 (Stream (context),
595+ output->MutableData <MLFloat16>() + num_prompt_tokens * num_heads_ * head_size_,
596+ query_data + num_prompt_tokens * num_heads_ * head_size_,
597+ key_cache->Data <MLFloat16>(),
598+ value_cache->Data <MLFloat16>(),
599+ head_mapping_.get (),
600+ scale_,
601+ reinterpret_cast <const int32_t *>(input_metadata->block_tables ),
602+ reinterpret_cast <const int32_t *>(input_metadata->context_lens ),
603+ value_cache->Shape ()[3 ],
604+ input_metadata->max_context_len ,
605+ nullptr ,
606+ input_metadata->max_num_blocks_per_seq ,
607+ generation_qeury_shape,
608+ num_queries_per_kv_, 1 );
609+ } else {
610+ auto tmp_output = this ->template GetScratchBuffer <T>(query_shape.Size () * max_num_partitions * sizeof (T), context->GetComputeStream ());
611+ auto exp_sums = this ->template GetScratchBuffer <T>(query_shape[0 ] * query_shape [1 ]* max_num_partitions * sizeof (T), context->GetComputeStream ());
612+ auto max_logits = this ->template GetScratchBuffer <T>(query_shape[0 ] * query_shape[1 ] * max_num_partitions * sizeof (T), context->GetComputeStream ());
613+ paged_attention_v2 (Stream (context),
614+ output->MutableData <MLFloat16>() + num_prompt_tokens * num_heads_ * head_size_,
615+ exp_sums.get (),
616+ max_logits.get (),
617+ tmp_output.get (),
618+ query_data + num_prompt_tokens * num_heads_ * head_size_,
619+ key_cache->Data <MLFloat16>(),
620+ value_cache->Data <MLFloat16>(),
621+ head_mapping_.get (),
622+ scale_,
623+ reinterpret_cast <const int32_t *>(input_metadata->block_tables ),
624+ reinterpret_cast <const int32_t *>(input_metadata->context_lens ),
625+ value_cache->Shape ()[3 ],
626+ input_metadata->max_context_len ,
627+ nullptr ,
628+ input_metadata->max_num_blocks_per_seq ,
629+ generation_qeury_shape,
630+ num_queries_per_kv_, 1 );
631+ }
603632 CHECK_CUDA_ERROR ();
604633 }
605634 return Status::OK ();
0 commit comments