Skip to content

Commit 9a2ba9d

Browse files
aciddelgadokleiti
authored andcommitted
GQA Memory Efficient Kernel (microsoft#17920)
Implement Cutlass Memory Efficient Attention Kernel into Group Query Attention Operator. ### Motivation and Context Before this change, Group Query Attention Operator was supported only by Flash-Attention. While this is the most efficient kernel for the operation, it only supports sm >= 80. Cutlass Memory Efficient Attention Kernel supports sm >= 53, allowing us to support a broader range of GPU hardware.
1 parent aca0171 commit 9a2ba9d

14 files changed

Lines changed: 843 additions & 312 deletions

cmake/onnxruntime_rocm_hipify.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ set(contrib_ops_excluded_files
9494
"cuda_contrib_kernels.h"
9595
"inverse.cc"
9696
"fused_conv.cc"
97+
"bert/group_query_attention_helper.h"
98+
"bert/group_query_attention.h"
99+
"bert/group_query_attention.cc"
100+
"bert/group_query_attention_impl.h"
101+
"bert/group_query_attention_impl.cu"
97102
)
98103

99104
if (NOT onnxruntime_ENABLE_ATEN)

docs/ContribOperators.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2422,14 +2422,14 @@ This version of the operator has been available since version 1 of the 'com.micr
24222422
<dd>When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.</dd>
24232423
</dl>
24242424

2425-
#### Outputs (1 - 3)
2425+
#### Outputs
24262426

24272427
<dl>
24282428
<dt><tt>output</tt> : T</dt>
24292429
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
2430-
<dt><tt>present_key</tt> (optional) : T</dt>
2430+
<dt><tt>present_key</tt> : T</dt>
24312431
<dd>present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
2432-
<dt><tt>present_value</tt> (optional) : T</dt>
2432+
<dt><tt>present_value</tt> : T</dt>
24332433
<dd>present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
24342434
</dl>
24352435

onnxruntime/contrib_ops/cuda/bert/attention_impl.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ Status EfficientAttention(
374374
p.num_heads = parameters.num_heads;
375375
p.sequence_length = parameters.sequence_length;
376376
p.kv_sequence_length = parameters.total_sequence_length;
377+
p.max_sequence_length = parameters.total_sequence_length;
377378
p.qk_head_size = parameters.head_size;
378379
p.v_head_size = parameters.v_head_size;
379380
p.causal = parameters.is_unidirectional;
@@ -395,6 +396,7 @@ Status EfficientAttention(
395396
p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
396397
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
397398
p.output = data.output;
399+
p.is_kv_bsnh = true;
398400
p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
399401
? data.scratch
400402
: nullptr;

onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
5151
p.num_keys = params.kv_sequence_length;
5252

5353
if (params.causal) {
54-
p.custom_mask_type = Attention::CausalFromTopLeft;
54+
p.custom_mask_type = Attention::CausalFromBottomRight;
5555
}
5656

57-
// Input format is BxSxNxH, output is BxSxNxH
58-
p.q_strideH = params.qk_head_size;
59-
p.k_strideH = params.qk_head_size;
60-
p.v_strideH = params.v_head_size;
61-
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
62-
63-
p.q_strideM = params.num_heads * params.qk_head_size;
64-
p.k_strideM = params.num_heads * params.qk_head_size;
65-
p.v_strideM = params.num_heads * params.v_head_size;
66-
p.o_strideM = params.num_heads * params.v_head_size;
67-
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
68-
69-
p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
70-
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.kv_sequence_length;
71-
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
72-
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
57+
// We use max_sequence_length to calculate KV stride
58+
if (params.is_kv_bsnh) {
59+
// Input Q, K, V format is BxSxNxH, output is BxSxNxH
60+
p.q_strideH = params.qk_head_size;
61+
p.k_strideH = params.qk_head_size;
62+
p.v_strideH = params.v_head_size;
63+
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
64+
65+
p.q_strideM = params.num_heads * params.qk_head_size;
66+
p.k_strideM = params.num_heads * params.qk_head_size;
67+
p.v_strideM = params.num_heads * params.v_head_size;
68+
p.o_strideM = params.num_heads * params.v_head_size;
69+
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
70+
71+
p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
72+
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.max_sequence_length;
73+
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.max_sequence_length;
74+
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
75+
} else {
76+
// Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH
77+
p.q_strideH = params.qk_head_size;
78+
p.k_strideH = params.max_sequence_length * params.qk_head_size;
79+
p.v_strideH = params.max_sequence_length * params.v_head_size;
80+
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
81+
82+
p.q_strideM = params.num_heads * params.qk_head_size;
83+
p.k_strideM = params.qk_head_size;
84+
p.v_strideM = params.v_head_size;
85+
p.o_strideM = params.num_heads * params.v_head_size;
86+
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
87+
88+
p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length;
89+
p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length;
90+
p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length;
91+
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
92+
}
7393
}
7494

7595
constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;

onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ namespace cuda {
1414
struct MemoryEfficientAttentionParams {
1515
int32_t sm;
1616
bool is_half;
17+
bool is_kv_bsnh = true;
1718
int32_t batch_size;
1819
int32_t num_heads;
1920
int32_t sequence_length;
2021
int32_t kv_sequence_length;
22+
int32_t max_sequence_length;
2123
int32_t qk_head_size;
2224
int32_t v_head_size;
2325
bool causal;

onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
77
#include "contrib_ops/cuda/bert/group_query_attention.h"
88
#include "contrib_ops/cuda/bert/group_query_attention_helper.h"
9+
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
910
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
10-
// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
11-
// #include "contrib_ops/cpu/utils/console_dumper.h"
1211

1312
using namespace onnxruntime::cuda;
1413
using namespace ::onnxruntime::common;
@@ -55,6 +54,13 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
5554
#else
5655
disable_flash_attention_ = true;
5756
#endif
57+
58+
#if USE_MEMORY_EFFICIENT_ATTENTION
59+
disable_memory_efficient_attention_ = sizeof(T) != 2 ||
60+
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
61+
#else
62+
disable_memory_efficient_attention_ = true;
63+
#endif
5864
}
5965

6066
template <typename T>
@@ -92,18 +98,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
9298
output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
9399
Tensor* output = context->Output(0, output_shape);
94100

95-
std::vector<int64_t> present_dims;
96-
if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
97-
present_dims = {
98-
parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
99-
} else { // BNSH
100-
present_dims = {
101-
parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
102-
}
103-
TensorShape present_shape(present_dims);
104-
Tensor* present_key = context->Output(1, present_shape);
105-
Tensor* present_value = context->Output(2, present_shape);
106-
107101
#if USE_FLASH_ATTENTION
108102
bool use_flash_attention = !disable_flash_attention_ &&
109103
onnxruntime::flash::is_supported(device_prop,
@@ -143,8 +137,47 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
143137
auto seqlens_k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
144138
#endif
145139

146-
// only kernel implemented for gqa right now
147-
ORT_ENFORCE(use_flash_attention);
140+
#if USE_MEMORY_EFFICIENT_ATTENTION
141+
int sm = (device_prop.major * 10) + device_prop.minor;
142+
bool use_memory_efficient_attention =
143+
!use_flash_attention &&
144+
!disable_memory_efficient_attention_ &&
145+
(parameters.head_size & 7) == 0 &&
146+
parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length &&
147+
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
148+
has_memory_efficient_attention(sm, sizeof(T) == 2);
149+
// allocate buffers
150+
size_t kv_buffer_bytes = 0;
151+
// need a buffer if we must ungroup kv
152+
const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads);
153+
if (use_memory_efficient_attention && needs_buff) {
154+
kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size);
155+
}
156+
size_t fmha_buffer_bytes = 0;
157+
if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
158+
fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));
159+
}
160+
auto k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
161+
auto v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
162+
auto fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
163+
#else
164+
constexpr bool use_memory_efficient_attention = false;
165+
auto k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
166+
auto v_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
167+
auto fmha_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
168+
#endif
169+
170+
std::vector<int64_t> present_dims;
171+
if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
172+
present_dims = {
173+
parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
174+
} else { // BNSH
175+
present_dims = {
176+
parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
177+
}
178+
TensorShape present_shape(present_dims);
179+
Tensor* present_key = context->Output(1, present_shape);
180+
Tensor* present_value = context->Output(2, present_shape);
148181

149182
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
150183
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
@@ -155,6 +188,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
155188
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
156189
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
157190
data.use_flash_attention = use_flash_attention;
191+
data.use_memory_efficient_attention = use_memory_efficient_attention;
158192
if (softmax_lse_buffer != nullptr) {
159193
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
160194
}
@@ -167,6 +201,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
167201
if (seqlens_k_buffer != nullptr) {
168202
data.seqlens_k = reinterpret_cast<int*>(seqlens_k_buffer.get());
169203
}
204+
if (k_buffer != nullptr) {
205+
data.k = reinterpret_cast<CudaT*>(k_buffer.get());
206+
data.v = reinterpret_cast<CudaT*>(v_buffer.get());
207+
}
208+
if (fmha_buffer != nullptr) {
209+
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
210+
}
170211

171212
cublasHandle_t cublas = GetCublasHandle(context);
172213

onnxruntime/contrib_ops/cuda/bert/group_query_attention.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel {
2727
bool is_past_bsnh_;
2828
float scale_;
2929
bool disable_flash_attention_;
30+
bool disable_memory_efficient_attention_;
3031
};
3132

3233
} // namespace cuda

onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h

Lines changed: 38 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ Status CheckInputs(const Tensor* query,
2929
// query (Q) : (B, S, D)
3030
// key (K) : (B, S+, D_kv)
3131
// value (V) : (B, S+, D_kv)
32+
ORT_UNUSED_PARAMETER(value);
3233

3334
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
3435
AttentionQkvFormat past_kv_format = Q_K_V_BSNH;
3536

3637
const auto& query_dims = query->Shape().GetDims();
3738
const auto& key_dims = key->Shape().GetDims();
38-
const auto& value_dims = value->Shape().GetDims();
3939

4040
if (query_dims.size() != 3) {
4141
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
@@ -47,10 +47,8 @@ Status CheckInputs(const Tensor* query,
4747
int q_hidden_size = static_cast<int>(query_dims[2]);
4848
int head_size = static_cast<int>(q_hidden_size) / num_heads;
4949

50-
int kv_sequence_length = sequence_length;
51-
int kv_hidden_size = (key_dims.size() == 3)
52-
? static_cast<int>(key_dims[2])
53-
: (kv_num_heads * static_cast<int>(key_dims[3]));
50+
int kv_sequence_length = static_cast<int>(key_dims[1]);
51+
int kv_hidden_size = static_cast<int>(key_dims[2]);
5452

5553
int max_sequence_length = 0;
5654
if (past_key != nullptr && past_value != nullptr) {
@@ -134,63 +132,49 @@ Status CheckInputs(const Tensor* query,
134132
"Input 'past_key' and 'past_value' shall be both present or both absent");
135133
}
136134

137-
if (key != nullptr) {
138-
const auto& key_dims = key->Shape().GetDims();
139-
if (key_dims.size() != 3) {
140-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
141-
key_dims.size());
142-
}
143-
if (query_dims[0] != key_dims[0]) {
144-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
145-
"Input 'query' and 'key' shall have same dim 0 (batch size)");
146-
}
147-
148-
if (num_heads % kv_num_heads != 0) {
149-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
150-
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
151-
num_heads % kv_num_heads);
152-
}
153-
if (key_dims[2] != value_dims[2]) {
154-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
155-
"Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)");
156-
}
157-
158-
qkv_format = Q_K_V_BSNH;
159-
kv_sequence_length = static_cast<int>(key_dims[1]);
160-
} else {
135+
if (key_dims.size() != 3) {
136+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
137+
key_dims.size());
138+
}
139+
if (query_dims[0] != key_dims[0]) {
161140
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
162-
"Missing key tensor.");
141+
"Input 'query' and 'key' shall have same dim 0 (batch size)");
163142
}
164143

165-
if (value != nullptr) {
166-
const auto& value_dims = value->Shape().GetDims();
167-
if (value_dims.size() != 3) {
168-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
169-
value_dims.size());
170-
}
144+
if (num_heads % kv_num_heads != 0) {
145+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
146+
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
147+
num_heads % kv_num_heads);
148+
}
171149

172-
if (query_dims[0] != value_dims[0]) {
173-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
174-
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
175-
}
150+
const auto& value_dims = value->Shape().GetDims();
151+
if (value_dims.size() != 3) {
152+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
153+
value_dims.size());
154+
}
176155

177-
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
178-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
179-
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
180-
}
156+
if (query_dims[0] != value_dims[0]) {
157+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
158+
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
159+
}
181160

182-
if (value_dims[2] != kv_hidden_size) {
183-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
184-
}
185-
} else {
161+
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
186162
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
187-
"Missing value tensor.");
163+
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
164+
}
165+
166+
if (value_dims[2] != kv_hidden_size) {
167+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
188168
}
189169

190170
// When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly.
191171
int32_t past_sequence_length = 0;
192-
int present_sequence_length = 0;
172+
int present_sequence_length = kv_sequence_length;
193173
if (past_seq_len != nullptr) {
174+
if (past_key == nullptr) {
175+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
176+
"Past KV must be present as share-buffer when using past_seq_len pointer.");
177+
}
194178
if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
195179
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
196180
"past_sequence_length tensor must be of one element when using past kv.");
@@ -200,6 +184,10 @@ Status CheckInputs(const Tensor* query,
200184
} else {
201185
past_sequence_length = static_cast<int32_t>(*((*past_seq_len).template Data<int64_t>()));
202186
}
187+
if (past_sequence_length + kv_sequence_length > max_sequence_length) {
188+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
189+
"KV buffer too small... shall be that max_sequence_length >= past_sequence_length + kv_sequence_length");
190+
}
203191
present_sequence_length = max_sequence_length;
204192
} else if (past_key != nullptr) {
205193
past_sequence_length = max_sequence_length; // this is the length of past_key tensor

0 commit comments

Comments
 (0)