Skip to content

Commit e548164

Browse files
committed
try support metadata optional
1 parent 4ebf976 commit e548164

3 files changed

Lines changed: 168 additions & 68 deletions

File tree

onnxruntime/contrib_ops/cuda/bert/paged_attention.cc

Lines changed: 142 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -217,48 +217,51 @@ PagedAttention<T>::PagedAttention(const OpKernelInfo& info) : CudaKernel(info),
217217
ORT_ENFORCE(info.GetAttr("mask_type", &mask_type_).IsOK() && (mask_type_ == "normal" || mask_type_ == "alibi" || mask_type_ == "RoPE"));
218218
}
219219

220-
// Fills a MemoryEfficientAttentionParams struct from the given tensors, the
// InputMetadata and the packed attention parameters, then calls
// run_memory_efficient_attention() on the supplied CUDA stream.
// NOTE(review): the "NNN-" diff markers interleaved below indicate that this
// function sits on the removed side of the commit shown on this page.
template <typename T>
221-
void MemoryEfficientAttn(const cudaDeviceProp& device_prop, cudaStream_t stream,
222-
const Tensor* query, const Tensor* key, const Tensor* value,
223-
Tensor* output, const InputMetadata* input_metadata,
224-
PackedAttentionParameters params) {
225-
MemoryEfficientAttentionParams attn_param;
226-
// Compute capability encoded as major * 10 + minor.
attn_param.sm = device_prop.major * 10 + device_prop.minor;
227-
// Any 2-byte element type T is flagged as half precision.
attn_param.is_half = sizeof(T) == 2;
228-
attn_param.batch_size = input_metadata->attn_bias.batchsize;
229-
attn_param.num_heads = params.num_heads;
230-
attn_param.sequence_length = input_metadata->attn_bias.q_seqinfo.max_seqlen;
231-
// No explicit KV length is passed; seqstart_k_ptr is set further down.
attn_param.kv_sequence_length = 0;
232-
// Q/K and V share the same head size.
attn_param.qk_head_size = params.head_size;
233-
attn_param.v_head_size = params.head_size;
234-
attn_param.causal = true;
235-
attn_param.scale = params.scale;
236-
attn_param.seqlen_k_ptr = nullptr;
237-
// The same q_seqinfo.seqstart value is reinterpreted as the int32 start-offset
// pointer for both the query and the key side.
attn_param.seqstart_q_ptr = reinterpret_cast<int32_t*>(input_metadata->attn_bias.q_seqinfo.seqstart);
238-
attn_param.seqstart_k_ptr = reinterpret_cast<int32_t*>(input_metadata->attn_bias.q_seqinfo.seqstart);
239-
attn_param.query = query->DataRaw();
240-
attn_param.key = key->DataRaw();
241-
attn_param.value = value->DataRaw();
242-
// No attention bias and no workspace buffer are supplied.
attn_param.attn_bias = nullptr;
243-
attn_param.is_attn_bias_batched = false;
244-
attn_param.output = output->MutableDataRaw();
245-
attn_param.workspace = nullptr;
246-
attn_param.stream = stream;
247-
run_memory_efficient_attention(attn_param);
248-
}
249-
250220
template <typename T>
251221
Status PagedAttention<T>::CheckInputs(
252-
const Tensor* query,
253-
const Tensor* key,
254-
const Tensor* value,
222+
OpKernelContext* context,
255223
const InputMetadata* input_metadata,
256224
PackedAttentionParameters& parameters) const {
257-
ORT_UNUSED_PARAMETER(query);
258-
ORT_UNUSED_PARAMETER(key);
259-
ORT_UNUSED_PARAMETER(value);
225+
const Tensor* query = context->Input<Tensor>(0);
226+
const Tensor* key_cache = context->Input<Tensor>(3);
227+
const Tensor* value_cache = context->Input<Tensor>(4);
228+
const Tensor* positions = context->Input<Tensor>(6);
229+
230+
const auto& query_shape = query->Shape();
231+
if (query_shape.NumDimensions() < 2 || query_shape.NumDimensions() > 3) {
232+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid query shape: ", query_shape, " expected 2 or 3 dimensions");
233+
}
234+
int64_t batch_size = 1;
235+
int64_t seq_len = query_shape[0];
236+
if (query_shape.NumDimensions() == 3) {
237+
batch_size = query_shape[0];
238+
seq_len = query_shape[1];
239+
}
240+
241+
if (batch_size != 1 && input_metadata->num_prompt_tokens * input_metadata->num_generation_tokens != 0) {
242+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
243+
"Invalid input_medata, batch_size should be 1 when prompt"
244+
" and generation tokens are both present");
245+
}
246+
247+
if (batch_size * seq_len < input_metadata->num_valid_tokens) {
248+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid query shape: ", query_shape,
249+
" expected at least ", input_metadata->num_valid_tokens, " tokens");
250+
}
251+
252+
if (key_cache->Shape().NumDimensions() != 5 || value_cache->Shape().NumDimensions() != 4) {
253+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid key_cache or value_cache shape: ",
254+
key_cache->Shape(), " ", value_cache->Shape());
255+
}
256+
257+
if (positions && positions->Shape().Size() > 0 &&
258+
positions->Shape()[positions->Shape().NumDimensions() - 1] != batch_size * seq_len) {
259+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid positions shape: ", positions->Shape());
260+
}
261+
260262
int64_t num_prompt_tokens = input_metadata->num_prompt_tokens;
261263

264+
// padding removed
262265
parameters.batch_size = input_metadata->attn_bias.batchsize;
263266
// parameters.sequence_length = gsl::narrow<int>(num_prompt_tokens);
264267
parameters.head_size = head_size_;
@@ -286,15 +289,14 @@ Status PagedAttention<T>::CheckInputs(
286289

287290
template <typename T>
288291
Status PagedAttention<T>::DoQKVProjectionIfNeed(OpKernelContext* context,
292+
InputMetadata* input_metadata,
289293
PackedAttentionParameters parameters,
290294
IAllocatorUniquePtr<T>& gemm_buffer) const {
291295
const Tensor* query = context->Input<Tensor>(0);
292296
const Tensor* key = context->Input<Tensor>(1);
293297
const Tensor* value = context->Input<Tensor>(2);
294298

295-
const Tensor* t_input_metadata = context->Input<Tensor>(5);
296-
InputMetadata* input_metadata = reinterpret_cast<InputMetadata*>(t_input_metadata->Data<int64_t>()[0]);
297-
299+
// query for input, key for weights, value for bias, so their dimensions are different.
298300
if (key->Shape().NumDimensions() == value->Shape().NumDimensions()) {
299301
return Status::OK();
300302
}
@@ -337,13 +339,12 @@ Status PagedAttention<T>::DoQKVProjectionIfNeed(OpKernelContext* context,
337339

338340
template <typename T>
339341
Status PagedAttention<T>::RunMultiHeadAttention(Tensor* output, OpKernelContext* context,
342+
InputMetadata* input_metadata,
340343
PackedAttentionParameters parameters,
341344
IAllocatorUniquePtr<T>& gemm_buffer) const {
342345
const Tensor* query = context->Input<Tensor>(0);
343346
const Tensor* key = context->Input<Tensor>(1);
344347
const Tensor* value = context->Input<Tensor>(2);
345-
const Tensor* t_input_metadata = context->Input<Tensor>(5);
346-
InputMetadata* input_metadata = reinterpret_cast<InputMetadata*>(t_input_metadata->Data<int64_t>()[0]);
347348

348349
const Tensor* bias = nullptr;
349350
const Tensor* relative_position_bias = nullptr;
@@ -402,27 +403,120 @@ Status PagedAttention<T>::RunMultiHeadAttention(Tensor* output, OpKernelContext*
402403
data.source_qkv_format = (key == nullptr) ? AttentionQkvFormat::QKV_TN3H : AttentionQkvFormat::Q_K_V_TNH;
403404
return QkvToContext<CudaT>(device_prop, cublas, this->Stream(context), parameters, data);
404405
}
406+
407+
InputMetadata* GetOrCreateMedataFromInput(OpKernelContext* context, InputMetadata* s_input_metadata, int8_t* meta_data_space) {
408+
const Tensor* t_input_metadata = context->Input<Tensor>(5);
409+
if (t_input_metadata && t_input_metadata->Data<int64_t>()[0]) {
410+
return reinterpret_cast<InputMetadata*>(t_input_metadata->Data<int64_t>()[0]);
411+
}
412+
413+
const Tensor* query = context->Input<Tensor>(0);
414+
const Tensor* key_cache = context->Input<Tensor>(3);
415+
int seq_len = query->Shape().NumDimensions() == 3 ? query->Shape()[1] : query->Shape()[0];
416+
417+
const Tensor* positions = context->Input<Tensor>(6);
418+
std::vector<int64_t> cpu_position(positions->Shape().Size());
419+
CUDA_CALL_THROW(cudaMemcpy(cpu_position.data(), positions->DataRaw(),
420+
positions->SizeInBytes(), cudaMemcpyDeviceToHost));
421+
while (cpu_position.back() == 0) {
422+
cpu_position.pop_back();
423+
seq_len--;
424+
}
425+
InputMetadata* input_metadata = s_input_metadata;
426+
input_metadata->num_valid_tokens = seq_len;
427+
428+
std::vector<int32_t> slot_mapping;
429+
std::vector<int32_t> context_lens;
430+
std::vector<int32_t> block_tables;
431+
std::vector<int32_t> seqstart;
432+
433+
input_metadata->max_context_len = 0;
434+
// in prompt mode
435+
if (cpu_position.back() == 0 ||
436+
cpu_position.size() > 1) {
437+
input_metadata->num_prompt_tokens = seq_len;
438+
input_metadata->num_generation_tokens = 0;
439+
slot_mapping.resize(input_metadata->num_prompt_tokens);
440+
std::iota(slot_mapping.begin(), slot_mapping.end(), 0);
441+
} else {
442+
int32_t block_size = gsl::narrow<int32_t>(key_cache->Shape()[3]);
443+
int32_t past_seq_len = cpu_position.back();
444+
input_metadata->num_prompt_tokens = 0;
445+
input_metadata->num_generation_tokens = seq_len;
446+
slot_mapping.push_back(past_seq_len);
447+
context_lens.push_back(past_seq_len + 1);
448+
for (int i = 0; i < past_seq_len + 1; i += block_size) {
449+
block_tables.push_back(i / block_size);
450+
}
451+
input_metadata->max_context_len = context_lens.back();
452+
}
453+
454+
if (block_tables.empty()) {
455+
input_metadata->block_tables = 0;
456+
} else {
457+
// copy to cuda
458+
CUDA_CALL_THROW(cudaMemcpy(meta_data_space, block_tables.data(),
459+
block_tables.size() * sizeof(int32_t), cudaMemcpyHostToDevice));
460+
input_metadata->block_tables = reinterpret_cast<int64_t>(meta_data_space);
461+
meta_data_space += block_tables.size() * sizeof(int32_t);
462+
}
463+
if (context_lens.empty()) {
464+
input_metadata->context_lens = 0;
465+
} else {
466+
// copy to cuda
467+
CUDA_CALL_THROW(cudaMemcpy(meta_data_space, context_lens.data(),
468+
context_lens.size() * sizeof(int32_t), cudaMemcpyHostToDevice));
469+
input_metadata->context_lens = reinterpret_cast<int64_t>(meta_data_space);
470+
meta_data_space += context_lens.size() * sizeof(int32_t);
471+
}
472+
{
473+
// copy to cuda
474+
CUDA_CALL_THROW(cudaMemcpy(meta_data_space, slot_mapping.data(),
475+
slot_mapping.size() * sizeof(int32_t), cudaMemcpyHostToDevice));
476+
input_metadata->slot_mapping = reinterpret_cast<int64_t>(meta_data_space);
477+
meta_data_space += slot_mapping.size() * sizeof(int32_t);
478+
}
479+
input_metadata->max_num_blocks_per_seq = block_tables.size();
480+
std::memset(input_metadata->cache_events.events, 0, sizeof(THEvent));
481+
482+
if (input_metadata->num_prompt_tokens > 0) {
483+
seqstart.push_back(0);
484+
seqstart.push_back(input_metadata->num_prompt_tokens);
485+
486+
// copy to cuda
487+
CUDA_CALL_THROW(cudaMemcpy(meta_data_space, seqstart.data(),
488+
seqstart.size() * sizeof(int32_t), cudaMemcpyHostToDevice));
489+
input_metadata->attn_bias.q_seqinfo.seqstart = reinterpret_cast<int64_t>(meta_data_space);
490+
input_metadata->attn_bias.q_seqinfo.max_seqlen = input_metadata->num_prompt_tokens;
491+
input_metadata->attn_bias.batchsize = 1;
492+
meta_data_space += seqstart.size() * sizeof(int32_t);
493+
}
494+
495+
return input_metadata;
496+
}
497+
405498
template <typename T>
406499
Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
407500
const Tensor* query = context->Input<Tensor>(0);
408501
const Tensor* key = context->Input<Tensor>(1);
409502
const Tensor* value = context->Input<Tensor>(2);
410503
const Tensor* key_cache = context->Input<Tensor>(3);
411504
const Tensor* value_cache = context->Input<Tensor>(4);
412-
const Tensor* t_input_metadata = context->Input<Tensor>(5);
413505
const Tensor* positions = context->Input<Tensor>(6);
414506
const Tensor* cos_sin_cache = context->Input<Tensor>(7);
415507
const Tensor* kv_quant_param = (context->InputCount() > 8) ? context->Input<Tensor>(8) : nullptr;
416508
ORT_UNUSED_PARAMETER(kv_quant_param);
417509

418-
InputMetadata* input_metadata = reinterpret_cast<InputMetadata*>(t_input_metadata->Data<int64_t>()[0]);
510+
int seq_len = query->Shape().NumDimensions() == 3? query->Shape()[1] : query->Shape()[0];
511+
auto meta_data_space = this->template GetScratchBuffer<int8_t>(std::max(1024, seq_len * 3) * sizeof(int32_t), context->GetComputeStream());
512+
InputMetadata self_build_input_metadata;
513+
InputMetadata* input_metadata = GetOrCreateMedataFromInput(context, &self_build_input_metadata, meta_data_space.get());
419514

420-
const auto& device_prop = GetDeviceProp();
421515
PackedAttentionParameters parameters;
422-
ORT_RETURN_IF_ERROR(CheckInputs(query, key, value, input_metadata, parameters));
516+
ORT_RETURN_IF_ERROR(CheckInputs(context, input_metadata, parameters));
423517

424518
IAllocatorUniquePtr<T> gemm_buffer;
425-
ORT_RETURN_IF_ERROR(DoQKVProjectionIfNeed(context, parameters, gemm_buffer));
519+
ORT_RETURN_IF_ERROR(DoQKVProjectionIfNeed(context, input_metadata, parameters, gemm_buffer));
426520

427521
T* query_data = const_cast<T*>(query->Data<T>());
428522
T* key_data = const_cast<T*>(key->Data<T>());
@@ -459,20 +553,9 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
459553
}
460554

461555
int64_t num_prompt_tokens = input_metadata->num_prompt_tokens;
462-
bool use_multihead_attn = ParseEnvironmentVariableWithDefault<bool>("use_multihead_attn", true);
463-
464556
if (num_prompt_tokens > 0) {
465-
if (use_multihead_attn) {
466-
ORT_RETURN_IF_ERROR(RunMultiHeadAttention(output, context, parameters, gemm_buffer));
467-
} else {
468-
MemoryEfficientAttn<MLFloat16>(device_prop, Stream(context),
469-
query,
470-
key,
471-
value,
472-
output,
473-
input_metadata,
474-
parameters);
475-
}
557+
ORT_RETURN_IF_ERROR(RunMultiHeadAttention(output, context, input_metadata, parameters, gemm_buffer));
558+
CHECK_CUDA_ERROR();
476559
}
477560

478561
auto key_cache_shape = key_cache->Shape();

onnxruntime/contrib_ops/cuda/bert/paged_attention.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,12 @@ class PagedAttention final : public TrtFusedAttention<T>, public CudaKernel {
5555

5656
private:
5757
Status CheckInputs(
58-
const Tensor* query,
59-
const Tensor* key,
60-
const Tensor* value,
58+
OpKernelContext* context,
6159
const InputMetadata* input_metadata,
6260
PackedAttentionParameters& parameters) const;
63-
Status RunMultiHeadAttention(Tensor* output, OpKernelContext* context, PackedAttentionParameters parameters, IAllocatorUniquePtr<T>& gemm_buffer) const;
64-
Status DoQKVProjectionIfNeed(OpKernelContext* context, PackedAttentionParameters parameters,
61+
Status RunMultiHeadAttention(Tensor* output, OpKernelContext* context, InputMetadata* input_metadata,
62+
PackedAttentionParameters parameters, IAllocatorUniquePtr<T>& gemm_buffer) const;
63+
Status DoQKVProjectionIfNeed(OpKernelContext* context, InputMetadata* input_metadata, PackedAttentionParameters parameters,
6564
IAllocatorUniquePtr<T>& gemm_buffer) const;
6665

6766
int32_t num_heads_; // number of attention heads

onnxruntime/core/graph/contrib_ops/bert_defs.cc

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,20 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
658658
PackedMultiHeadAttentionTypeAndShapeInference(ctx);
659659
}));
660660

661+
// Propagates the type and shape of the first input to output 0, then rewrites every
// output dimension that carries a dim_param and reports dim_value() == 0 to the
// explicit value -1.
void propagateShapeAndTypeFromFirstInputAndParam(ONNX_NAMESPACE::InferenceContext& ctx) {
  propagateShapeAndTypeFromFirstInput(ctx);

  auto* inferred_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
  int axis = 0;
  while (axis < inferred_shape->dim_size()) {
    auto* dim = inferred_shape->mutable_dim(axis++);
    const bool is_symbolic_without_value = dim->has_dim_param() && dim->dim_value() == 0;
    if (!is_symbolic_without_value) {
      continue;
    }
    dim->set_dim_value(-1);
  }
}
674+
661675
constexpr const char* PagedAttention_ver1_doc = R"DOC(
662676
PagedAttention is from https://vllm.ai/
663677
It consists of two types of attention.
@@ -686,17 +700,18 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
686700
.Input(2, "value", "The input V-Tensor with shape(batch,seqlen,num-heads, head-size).", "T")
687701
.Input(3, "key_cache", "Blocked key cache in this layer.", "T2")
688702
.Input(4, "value_cache", "Blocked value cache in this layer.", "T2")
689-
.Input(5, "input_metadata", "Block mapping for each token, and some other eseential infos in InputMetadata, This input Tensor has shape [1], the value is a pointer of struct InputMetadata. It should be converted into a class and used then", "T1")
703+
.Input(5, "input_metadata", "Block mapping for each token, and some other eseential infos in InputMetadata, This input Tensor has shape [1], the value is a pointer of struct InputMetadata. It should be converted into a class and used then", "T1", OpSchema::Optional)
690704
.Input(6, "positions", "positions used for RoPE embedding", "T1", OpSchema::Optional)
691-
.Input(7, "cos_sin_cache", "cos_sin_cache used for RoPE embedding", "T", OpSchema::Optional)
705+
.Input(7, "cos_sin_cache_or_alibi_bais", "cos_sin_cache used for RoPE embedding, alibi for alibi embinding", "T3", OpSchema::Optional)
692706
.Input(8, "kv_quant_param", "quantization param for kvcache, like scale and zeropoint", "T", OpSchema::Optional)
693707
.Output(0, "output", "Attention output", "T")
694708
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(bfloat16)"},
695709
"Constrain input and output types to float/ tensors.")
696710
.TypeConstraint("T1", {"tensor(int64)"}, "Constrain input META types to pointer tensors.")
697711
.TypeConstraint("T2", {"tensor(int8)", "tensor(float16)", "tensor(float)", "tensor(bfloat16)"}, "kvcache and quant scale")
712+
.TypeConstraint("T3", {"tensor(float16)", "tensor(float)", "tensor(bfloat16)"}, "alibi scopt or cos_sin_cache")
698713
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
699-
propagateShapeAndTypeFromFirstInput(ctx);
714+
propagateShapeAndTypeFromFirstInputAndParam(ctx);
700715
}));
701716

702717
ONNX_MS_OPERATOR_SET_SCHEMA(
@@ -709,7 +724,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
709724
.Input(1, "weight", "2D input tensor with shape (hidden_size)", "T")
710725
.Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
711726
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
712-
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
727+
.TypeAndShapeInferenceFunction(
728+
[](ONNX_NAMESPACE::InferenceContext& ctx) {
729+
propagateShapeAndTypeFromFirstInputAndParam(ctx);
730+
}));
713731

714732
void SiluMulShapeInfer(InferenceContext& ctx) {
715733
auto* output_shape =

0 commit comments

Comments
 (0)