@@ -217,48 +217,51 @@ PagedAttention<T>::PagedAttention(const OpKernelInfo& info) : CudaKernel(info),
217217 ORT_ENFORCE (info.GetAttr (" mask_type" , &mask_type_).IsOK () && (mask_type_ == " normal" || mask_type_ == " alibi" || mask_type_ == " RoPE" ));
218218}
219219
220- template <typename T>
221- void MemoryEfficientAttn (const cudaDeviceProp& device_prop, cudaStream_t stream,
222- const Tensor* query, const Tensor* key, const Tensor* value,
223- Tensor* output, const InputMetadata* input_metadata,
224- PackedAttentionParameters params) {
225- MemoryEfficientAttentionParams attn_param;
226- attn_param.sm = device_prop.major * 10 + device_prop.minor ;
227- attn_param.is_half = sizeof (T) == 2 ;
228- attn_param.batch_size = input_metadata->attn_bias .batchsize ;
229- attn_param.num_heads = params.num_heads ;
230- attn_param.sequence_length = input_metadata->attn_bias .q_seqinfo .max_seqlen ;
231- attn_param.kv_sequence_length = 0 ;
232- attn_param.qk_head_size = params.head_size ;
233- attn_param.v_head_size = params.head_size ;
234- attn_param.causal = true ;
235- attn_param.scale = params.scale ;
236- attn_param.seqlen_k_ptr = nullptr ;
237- attn_param.seqstart_q_ptr = reinterpret_cast <int32_t *>(input_metadata->attn_bias .q_seqinfo .seqstart );
238- attn_param.seqstart_k_ptr = reinterpret_cast <int32_t *>(input_metadata->attn_bias .q_seqinfo .seqstart );
239- attn_param.query = query->DataRaw ();
240- attn_param.key = key->DataRaw ();
241- attn_param.value = value->DataRaw ();
242- attn_param.attn_bias = nullptr ;
243- attn_param.is_attn_bias_batched = false ;
244- attn_param.output = output->MutableDataRaw ();
245- attn_param.workspace = nullptr ;
246- attn_param.stream = stream;
247- run_memory_efficient_attention (attn_param);
248- }
249-
250220template <typename T>
251221Status PagedAttention<T>::CheckInputs(
252- const Tensor* query,
253- const Tensor* key,
254- const Tensor* value,
222+ OpKernelContext* context,
255223 const InputMetadata* input_metadata,
256224 PackedAttentionParameters& parameters) const {
257- ORT_UNUSED_PARAMETER (query);
258- ORT_UNUSED_PARAMETER (key);
259- ORT_UNUSED_PARAMETER (value);
225+ const Tensor* query = context->Input <Tensor>(0 );
226+ const Tensor* key_cache = context->Input <Tensor>(3 );
227+ const Tensor* value_cache = context->Input <Tensor>(4 );
228+ const Tensor* positions = context->Input <Tensor>(6 );
229+
230+ const auto & query_shape = query->Shape ();
231+ if (query_shape.NumDimensions () < 2 || query_shape.NumDimensions () > 3 ) {
232+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Invalid query shape: " , query_shape, " expected 2 or 3 dimensions" );
233+ }
234+ int64_t batch_size = 1 ;
235+ int64_t seq_len = query_shape[0 ];
236+ if (query_shape.NumDimensions () == 3 ) {
237+ batch_size = query_shape[0 ];
238+ seq_len = query_shape[1 ];
239+ }
240+
241+ if (batch_size != 1 && input_metadata->num_prompt_tokens * input_metadata->num_generation_tokens != 0 ) {
242+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL,
243+ " Invalid input_medata, batch_size should be 1 when prompt"
244+ " and generation tokens are both present" );
245+ }
246+
247+ if (batch_size * seq_len < input_metadata->num_valid_tokens ) {
248+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Invalid query shape: " , query_shape,
249+ " expected at least " , input_metadata->num_valid_tokens , " tokens" );
250+ }
251+
252+ if (key_cache->Shape ().NumDimensions () != 5 || value_cache->Shape ().NumDimensions () != 4 ) {
253+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Invalid key_cache or value_cache shape: " ,
254+ key_cache->Shape (), " " , value_cache->Shape ());
255+ }
256+
257+ if (positions && positions->Shape ().Size () > 0 &&
258+ positions->Shape ()[positions->Shape ().NumDimensions () - 1 ] != batch_size * seq_len) {
259+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Invalid positions shape: " , positions->Shape ());
260+ }
261+
260262 int64_t num_prompt_tokens = input_metadata->num_prompt_tokens ;
261263
264+ // padding removed
262265 parameters.batch_size = input_metadata->attn_bias .batchsize ;
263266 // parameters.sequence_length = gsl::narrow<int>(num_prompt_tokens);
264267 parameters.head_size = head_size_;
@@ -286,15 +289,14 @@ Status PagedAttention<T>::CheckInputs(
286289
287290template <typename T>
288291Status PagedAttention<T>::DoQKVProjectionIfNeed(OpKernelContext* context,
292+ InputMetadata* input_metadata,
289293 PackedAttentionParameters parameters,
290294 IAllocatorUniquePtr<T>& gemm_buffer) const {
291295 const Tensor* query = context->Input <Tensor>(0 );
292296 const Tensor* key = context->Input <Tensor>(1 );
293297 const Tensor* value = context->Input <Tensor>(2 );
294298
295- const Tensor* t_input_metadata = context->Input <Tensor>(5 );
296- InputMetadata* input_metadata = reinterpret_cast <InputMetadata*>(t_input_metadata->Data <int64_t >()[0 ]);
297-
299+ // query for input, key for weights, value for bias, so their dimensions are different.
298300 if (key->Shape ().NumDimensions () == value->Shape ().NumDimensions ()) {
299301 return Status::OK ();
300302 }
@@ -337,13 +339,12 @@ Status PagedAttention<T>::DoQKVProjectionIfNeed(OpKernelContext* context,
337339
338340template <typename T>
339341Status PagedAttention<T>::RunMultiHeadAttention(Tensor* output, OpKernelContext* context,
342+ InputMetadata* input_metadata,
340343 PackedAttentionParameters parameters,
341344 IAllocatorUniquePtr<T>& gemm_buffer) const {
342345 const Tensor* query = context->Input <Tensor>(0 );
343346 const Tensor* key = context->Input <Tensor>(1 );
344347 const Tensor* value = context->Input <Tensor>(2 );
345- const Tensor* t_input_metadata = context->Input <Tensor>(5 );
346- InputMetadata* input_metadata = reinterpret_cast <InputMetadata*>(t_input_metadata->Data <int64_t >()[0 ]);
347348
348349 const Tensor* bias = nullptr ;
349350 const Tensor* relative_position_bias = nullptr ;
@@ -402,27 +403,120 @@ Status PagedAttention<T>::RunMultiHeadAttention(Tensor* output, OpKernelContext*
402403 data.source_qkv_format = (key == nullptr ) ? AttentionQkvFormat::QKV_TN3H : AttentionQkvFormat::Q_K_V_TNH;
403404 return QkvToContext<CudaT>(device_prop, cublas, this ->Stream (context), parameters, data);
404405}
406+
407+ InputMetadata* GetOrCreateMedataFromInput (OpKernelContext* context, InputMetadata* s_input_metadata, int8_t * meta_data_space) {
408+ const Tensor* t_input_metadata = context->Input <Tensor>(5 );
409+ if (t_input_metadata && t_input_metadata->Data <int64_t >()[0 ]) {
410+ return reinterpret_cast <InputMetadata*>(t_input_metadata->Data <int64_t >()[0 ]);
411+ }
412+
413+ const Tensor* query = context->Input <Tensor>(0 );
414+ const Tensor* key_cache = context->Input <Tensor>(3 );
415+ int seq_len = query->Shape ().NumDimensions () == 3 ? query->Shape ()[1 ] : query->Shape ()[0 ];
416+
417+ const Tensor* positions = context->Input <Tensor>(6 );
418+ std::vector<int64_t > cpu_position (positions->Shape ().Size ());
419+ CUDA_CALL_THROW (cudaMemcpy (cpu_position.data (), positions->DataRaw (),
420+ positions->SizeInBytes (), cudaMemcpyDeviceToHost));
421+ while (cpu_position.back () == 0 ) {
422+ cpu_position.pop_back ();
423+ seq_len--;
424+ }
425+ InputMetadata* input_metadata = s_input_metadata;
426+ input_metadata->num_valid_tokens = seq_len;
427+
428+ std::vector<int32_t > slot_mapping;
429+ std::vector<int32_t > context_lens;
430+ std::vector<int32_t > block_tables;
431+ std::vector<int32_t > seqstart;
432+
433+ input_metadata->max_context_len = 0 ;
434+ // in prompt mode
435+ if (cpu_position.back () == 0 ||
436+ cpu_position.size () > 1 ) {
437+ input_metadata->num_prompt_tokens = seq_len;
438+ input_metadata->num_generation_tokens = 0 ;
439+ slot_mapping.resize (input_metadata->num_prompt_tokens );
440+ std::iota (slot_mapping.begin (), slot_mapping.end (), 0 );
441+ } else {
442+ int32_t block_size = gsl::narrow<int32_t >(key_cache->Shape ()[3 ]);
443+ int32_t past_seq_len = cpu_position.back ();
444+ input_metadata->num_prompt_tokens = 0 ;
445+ input_metadata->num_generation_tokens = seq_len;
446+ slot_mapping.push_back (past_seq_len);
447+ context_lens.push_back (past_seq_len + 1 );
448+ for (int i = 0 ; i < past_seq_len + 1 ; i += block_size) {
449+ block_tables.push_back (i / block_size);
450+ }
451+ input_metadata->max_context_len = context_lens.back ();
452+ }
453+
454+ if (block_tables.empty ()) {
455+ input_metadata->block_tables = 0 ;
456+ } else {
457+ // copy to cuda
458+ CUDA_CALL_THROW (cudaMemcpy (meta_data_space, block_tables.data (),
459+ block_tables.size () * sizeof (int32_t ), cudaMemcpyHostToDevice));
460+ input_metadata->block_tables = reinterpret_cast <int64_t >(meta_data_space);
461+ meta_data_space += block_tables.size () * sizeof (int32_t );
462+ }
463+ if (context_lens.empty ()) {
464+ input_metadata->context_lens = 0 ;
465+ } else {
466+ // copy to cuda
467+ CUDA_CALL_THROW (cudaMemcpy (meta_data_space, context_lens.data (),
468+ context_lens.size () * sizeof (int32_t ), cudaMemcpyHostToDevice));
469+ input_metadata->context_lens = reinterpret_cast <int64_t >(meta_data_space);
470+ meta_data_space += context_lens.size () * sizeof (int32_t );
471+ }
472+ {
473+ // copy to cuda
474+ CUDA_CALL_THROW (cudaMemcpy (meta_data_space, slot_mapping.data (),
475+ slot_mapping.size () * sizeof (int32_t ), cudaMemcpyHostToDevice));
476+ input_metadata->slot_mapping = reinterpret_cast <int64_t >(meta_data_space);
477+ meta_data_space += slot_mapping.size () * sizeof (int32_t );
478+ }
479+ input_metadata->max_num_blocks_per_seq = block_tables.size ();
480+ std::memset (input_metadata->cache_events .events , 0 , sizeof (THEvent));
481+
482+ if (input_metadata->num_prompt_tokens > 0 ) {
483+ seqstart.push_back (0 );
484+ seqstart.push_back (input_metadata->num_prompt_tokens );
485+
486+ // copy to cuda
487+ CUDA_CALL_THROW (cudaMemcpy (meta_data_space, seqstart.data (),
488+ seqstart.size () * sizeof (int32_t ), cudaMemcpyHostToDevice));
489+ input_metadata->attn_bias .q_seqinfo .seqstart = reinterpret_cast <int64_t >(meta_data_space);
490+ input_metadata->attn_bias .q_seqinfo .max_seqlen = input_metadata->num_prompt_tokens ;
491+ input_metadata->attn_bias .batchsize = 1 ;
492+ meta_data_space += seqstart.size () * sizeof (int32_t );
493+ }
494+
495+ return input_metadata;
496+ }
497+
405498template <typename T>
406499Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
407500 const Tensor* query = context->Input <Tensor>(0 );
408501 const Tensor* key = context->Input <Tensor>(1 );
409502 const Tensor* value = context->Input <Tensor>(2 );
410503 const Tensor* key_cache = context->Input <Tensor>(3 );
411504 const Tensor* value_cache = context->Input <Tensor>(4 );
412- const Tensor* t_input_metadata = context->Input <Tensor>(5 );
413505 const Tensor* positions = context->Input <Tensor>(6 );
414506 const Tensor* cos_sin_cache = context->Input <Tensor>(7 );
415507 const Tensor* kv_quant_param = (context->InputCount () > 8 ) ? context->Input <Tensor>(8 ) : nullptr ;
416508 ORT_UNUSED_PARAMETER (kv_quant_param);
417509
418- InputMetadata* input_metadata = reinterpret_cast <InputMetadata*>(t_input_metadata->Data <int64_t >()[0 ]);
510+ int seq_len = query->Shape ().NumDimensions () == 3 ? query->Shape ()[1 ] : query->Shape ()[0 ];
511+ auto meta_data_space = this ->template GetScratchBuffer <int8_t >(std::max (1024 , seq_len * 3 ) * sizeof (int32_t ), context->GetComputeStream ());
512+ InputMetadata self_build_input_metadata;
513+ InputMetadata* input_metadata = GetOrCreateMedataFromInput (context, &self_build_input_metadata, meta_data_space.get ());
419514
420- const auto & device_prop = GetDeviceProp ();
421515 PackedAttentionParameters parameters;
422- ORT_RETURN_IF_ERROR (CheckInputs (query, key, value , input_metadata, parameters));
516+ ORT_RETURN_IF_ERROR (CheckInputs (context , input_metadata, parameters));
423517
424518 IAllocatorUniquePtr<T> gemm_buffer;
425- ORT_RETURN_IF_ERROR (DoQKVProjectionIfNeed (context, parameters, gemm_buffer));
519+ ORT_RETURN_IF_ERROR (DoQKVProjectionIfNeed (context, input_metadata, parameters, gemm_buffer));
426520
427521 T* query_data = const_cast <T*>(query->Data <T>());
428522 T* key_data = const_cast <T*>(key->Data <T>());
@@ -459,20 +553,9 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
459553 }
460554
461555 int64_t num_prompt_tokens = input_metadata->num_prompt_tokens ;
462- bool use_multihead_attn = ParseEnvironmentVariableWithDefault<bool >(" use_multihead_attn" , true );
463-
464556 if (num_prompt_tokens > 0 ) {
465- if (use_multihead_attn) {
466- ORT_RETURN_IF_ERROR (RunMultiHeadAttention (output, context, parameters, gemm_buffer));
467- } else {
468- MemoryEfficientAttn<MLFloat16>(device_prop, Stream (context),
469- query,
470- key,
471- value,
472- output,
473- input_metadata,
474- parameters);
475- }
557+ ORT_RETURN_IF_ERROR (RunMultiHeadAttention (output, context, input_metadata, parameters, gemm_buffer));
558+ CHECK_CUDA_ERROR ();
476559 }
477560
478561 auto key_cache_shape = key_cache->Shape ();
0 commit comments