Skip to content

Commit c32c494

Browse files
committed
new changes in vLLM
1 parent 308858a commit c32c494

3 files changed

Lines changed: 574 additions & 168 deletions

File tree

onnxruntime/contrib_ops/cuda/bert/paged_attention.cc

Lines changed: 54 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -226,7 +226,7 @@ Status PagedAttention<T>::CheckInputs(
226226
const Tensor* query = context->Input<Tensor>(0);
227227
const Tensor* key_cache = context->Input<Tensor>(3);
228228
const Tensor* value_cache = context->Input<Tensor>(4);
229-
const Tensor* positions = context->Input<Tensor>(6);
229+
//const Tensor* positions = context->Input<Tensor>(6);
230230

231231
const auto& query_shape = query->Shape();
232232
if (query_shape.NumDimensions() < 2 || query_shape.NumDimensions() > 3) {
@@ -255,10 +255,10 @@ Status PagedAttention<T>::CheckInputs(
255255
key_cache->Shape(), " ", value_cache->Shape());
256256
}
257257

258-
if (positions && positions->Shape().Size() > 0 &&
259-
positions->Shape()[positions->Shape().NumDimensions() - 1] != batch_size * seq_len) {
260-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid positions shape: ", positions->Shape());
261-
}
258+
//if (positions && positions->Shape().Size() > 0 &&
259+
// positions->Shape()[positions->Shape().NumDimensions() - 1] != batch_size * seq_len) {
260+
// return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid positions shape: ", positions->Shape());
261+
//}
262262

263263
int64_t num_prompt_tokens = input_metadata->num_prompt_tokens;
264264

@@ -507,8 +507,8 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
507507
const Tensor* cos_sin_cache = context->Input<Tensor>(7);
508508
const Tensor* kv_quant_param = (context->InputCount() > 8) ? context->Input<Tensor>(8) : nullptr;
509509
ORT_UNUSED_PARAMETER(kv_quant_param);
510-
511-
int seq_len = query->Shape().NumDimensions() == 3? query->Shape()[1] : query->Shape()[0];
510+
const auto& query_shape = query->Shape();
511+
int seq_len = query_shape[1] * query_shape[0];
512512
auto meta_data_space = this->template GetScratchBuffer<int8_t>(std::max(1024, seq_len * 3) * sizeof(int32_t), context->GetComputeStream());
513513
InputMetadata self_build_input_metadata;
514514
InputMetadata* input_metadata = GetOrCreateMedataFromInput(context, &self_build_input_metadata, meta_data_space.get());
@@ -524,9 +524,9 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
524524
T* value_data = const_cast<T*>(value->Data<T>());
525525

526526
int64_t num_valid_tokens = input_metadata->num_valid_tokens;
527-
TensorShape output_shape = query->Shape();
527+
TensorShape output_shape = query_shape;
528528
if (gemm_buffer.get() == nullptr) {
529-
ORT_ENFORCE(query->Shape()[1] == num_heads_ * head_size_, "invlaid query shape");
529+
ORT_ENFORCE(query_shape[output_shape.NumDimensions() - 1] == num_heads_ * head_size_, "invlaid query shape");
530530
} else {
531531
output_shape[output_shape.NumDimensions() - 1] = num_heads_ * head_size_;
532532
TensorShapeVector new_shape(2);
@@ -548,7 +548,8 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
548548
int64_t rot_dim = cos_sin_cache->Shape()[1];
549549
ORT_ENFORCE(rot_dim == head_size_, "RoPE mask requires position input with shape [seq_len, head_size]");
550550
rotary_embedding_neox(Stream(context), positions->Data<int64_t>(), static_cast<void*>(query_data),
551-
static_cast<void*>(key_data), head_size_, cos_sin_cache->DataRaw(), num_valid_tokens,
551+
static_cast<void*>(key_data), head_size_, cos_sin_cache->DataRaw(),
552+
query_shape.Size()/head_size_/num_heads_,
552553
rot_dim, num_heads_, num_kv_heads_, 1);
553554
CHECK_CUDA_ERROR();
554555
}
@@ -584,22 +585,50 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
584585
}
585586

586587
if (input_metadata->num_generation_tokens > 0) {
588+
constexpr int PARTITION_SIZE = 512;
589+
int max_num_partitions = ((input_metadata->max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE);
590+
//TODO : Tune this heuristic.
591+
bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE;
587592
int64_t generation_qeury_shape[3] = {num_valid_tokens - num_prompt_tokens, num_heads_, head_size_};
588-
single_query_cached_kv_attention(Stream(context),
589-
output->MutableData<MLFloat16>() + num_prompt_tokens * num_heads_ * head_size_,
590-
query_data + num_prompt_tokens * num_heads_ * head_size_,
591-
key_cache->Data<MLFloat16>(),
592-
value_cache->Data<MLFloat16>(),
593-
head_mapping_.get(),
594-
scale_,
595-
reinterpret_cast<const int32_t*>(input_metadata->block_tables),
596-
input_metadata->max_num_blocks_per_seq,
597-
reinterpret_cast<const int32_t*>(input_metadata->context_lens),
598-
value_cache->Shape()[3],
599-
input_metadata->max_context_len,
600-
nullptr,
601-
generation_qeury_shape,
602-
num_queries_per_kv_, 1);
593+
if (use_v1){
594+
paged_attention_v1(Stream(context),
595+
output->MutableData<MLFloat16>() + num_prompt_tokens * num_heads_ * head_size_,
596+
query_data + num_prompt_tokens * num_heads_ * head_size_,
597+
key_cache->Data<MLFloat16>(),
598+
value_cache->Data<MLFloat16>(),
599+
head_mapping_.get(),
600+
scale_,
601+
reinterpret_cast<const int32_t*>(input_metadata->block_tables),
602+
reinterpret_cast<const int32_t*>(input_metadata->context_lens),
603+
value_cache->Shape()[3],
604+
input_metadata->max_context_len,
605+
nullptr,
606+
input_metadata->max_num_blocks_per_seq,
607+
generation_qeury_shape,
608+
num_queries_per_kv_, 1);
609+
} else {
610+
auto tmp_output = this->template GetScratchBuffer<T>(query_shape.Size() * max_num_partitions * sizeof(T), context->GetComputeStream());
611+
auto exp_sums = this->template GetScratchBuffer<T>(query_shape[0] * query_shape [1]* max_num_partitions * sizeof(T), context->GetComputeStream());
612+
auto max_logits = this->template GetScratchBuffer<T>(query_shape[0] * query_shape[1] * max_num_partitions * sizeof(T), context->GetComputeStream());
613+
paged_attention_v2(Stream(context),
614+
output->MutableData<MLFloat16>() + num_prompt_tokens * num_heads_ * head_size_,
615+
exp_sums.get(),
616+
max_logits.get(),
617+
tmp_output.get(),
618+
query_data + num_prompt_tokens * num_heads_ * head_size_,
619+
key_cache->Data<MLFloat16>(),
620+
value_cache->Data<MLFloat16>(),
621+
head_mapping_.get(),
622+
scale_,
623+
reinterpret_cast<const int32_t*>(input_metadata->block_tables),
624+
reinterpret_cast<const int32_t*>(input_metadata->context_lens),
625+
value_cache->Shape()[3],
626+
input_metadata->max_context_len,
627+
nullptr,
628+
input_metadata->max_num_blocks_per_seq,
629+
generation_qeury_shape,
630+
num_queries_per_kv_, 1);
631+
}
603632
CHECK_CUDA_ERROR();
604633
}
605634
return Status::OK();

0 commit comments

Comments
 (0)