
Commit 80536a5

Author: Vladimir Paramuzov (committed)
Commit message: 2 kernels
1 parent 601a6e9, commit 80536a5

2 files changed: +75 -53 lines

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_micro.cpp
66 additions, 49 deletions

@@ -6,11 +6,11 @@
 #include "common_tools.h"
 #include "common_types.h"
 #include "jitter.h"
+#include "kernel_selector_common.h"
 #include "kernel_selector_params.h"
 #include "micro_utils.hpp"
 #include "tensor_type.h"

-
 #include <algorithm>
 #include <mutex>
 #include <string>
@@ -178,7 +178,7 @@ sdpa_config_t *choose_config_xehpc(int head_size, int seq, bool thin_q) {

 std::mutex SDPAKernelMicro::m;

-void SDPAKernelMicro::init_microkernels(const sdpa_params& params) const {
+void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs, bool is_prefill) const {
     std::lock_guard<std::mutex> l(m);
     const auto& Q = params.inputs[0];
     const auto& K = params.inputs[1];
@@ -187,14 +187,14 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params) const {
     auto& out = params.outputs[0];
     const auto head_size = params.conf.head_size;
     const auto d_max = get_d_max(head_size);
-    const Tensor::Dim n_keys = K.X();
+    const Tensor::Dim n_keys = K.X().v; //get_seq_length(K, params.input1_order);
     const Tensor::Dim n_queries = get_seq_length(Q, params.input0_order);
     const Tensor::Dim n_values = V.X();
     const auto batch = out.Batch().v * out.Feature().v;

     /* Retrieve pre-tuned kernel configuration */
     sdpa_config_t *config = nullptr;
-    bool thin_q = !n_queries.is_dynamic && (n_queries.v <= 16);
+    bool thin_q = (!n_queries.is_dynamic && (n_queries.v <= 16)) || !is_prefill;

     switch (params.engineInfo.arch) {
         case gpu_arch::xe_hpg: {
@@ -223,7 +223,7 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params) const {
     problem.Ts = problem.Tc;

     auto problem_kq = problem;
-    problem_kq.A.layout = micro::MatrixLayout::T; // TODO: support transpose with MatrixLayout::N layout
+    problem_kq.A.layout = micro::MatrixLayout::T;
     problem_kq.B.layout = micro::MatrixLayout::Pr;
     problem_kq.C.layout = micro::MatrixLayout::T;
     problem_kq.A.setAlignment(micro::alignment_for_ld(head_size * problem.Ta));
@@ -252,7 +252,7 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params) const {
     opts_kq.slmPtr = true;

     /* Ask microkernel provider for microkernel */
-    gemm_kq = selectGEMMMicrokernel(opts_kq, hw_info, sizes, problem_kq, reqs_kq);
+    gemm_kq = micro::select_gemm_microkernel(opts_kq, hw_info, sizes, problem_kq, reqs_kq);

     /* Update for second GEMM: V*S */
     auto problem_vs = problem;
@@ -334,7 +334,7 @@ bool SDPAKernelMicro::Validate(const Params& p) const {
     return true;
 }

-JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params) const {
+JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs) const {
     auto jit = MakeBaseParamsJitConstants(params);
     const auto& prim_params = dynamic_cast<const sdpa_params&>(params);

@@ -452,7 +452,7 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params) const {
     return jit;
 }

-CommonDispatchData SDPAKernelMicro::SetDefault(const sdpa_params& params) const {
+CommonDispatchData SDPAKernelMicro::SetDefault(const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs) const {
     CommonDispatchData dispatch_data;

     auto wg_tile_q = gemm_kq.getSetting("wg_tile_n");
@@ -468,28 +468,17 @@ CommonDispatchData SDPAKernelMicro::SetDefault(const sdpa_params& params) const
     return dispatch_data;
 }

-KernelsData SDPAKernelMicro::GetKernelsData(const Params& params) const {
-    KernelData kd = KernelData::Default<sdpa_params>(params);
-    const auto& prim_params = dynamic_cast<const sdpa_params&>(params);
-
-    if (!Validate(params)) {
-        return {};
-    }
-
-    init_microkernels(prim_params);
-
-    auto dispatchData = SetDefault(prim_params);
-    auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, params);
-    auto cldnn_jit = GetJitConstants(prim_params);
-    auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
+clKernelData SDPAKernelMicro::get_kernel_data(const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs, bool is_prefill) const {
+    auto name = kernelName + (is_prefill ? "_prefill" : "_generate");
+    init_microkernels(params, gemm_kq, gemm_vs, is_prefill);
+    auto dispatch_data = SetDefault(params, gemm_kq, gemm_vs);
+    auto entry_point = GetEntryPoint(name, params.layerID, params);
+    auto jit = CreateJit(name, GetJitConstants(params, gemm_kq, gemm_vs), entry_point);
+    clKernelData kernel;

-    auto& kernel = kd.kernels[0];
-
-    GetUpdateDispatchDataFunc(kd);
-
-    FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point,
-                     "", false, false, static_cast<int>(prim_params.inputs.size()),
-                     GetFusedPrimitiveInputsCount(params), 1, prim_params.is_shape_agnostic);
+    FillCLKernelData(kernel, dispatch_data, params.engineInfo, kernelName, jit, entry_point,
+                     "", false, false, static_cast<int>(params.inputs.size()),
+                     GetFusedPrimitiveInputsCount(params), 1, params.is_shape_agnostic);

     kernel.params.arguments.clear();
     if (params.is_shape_agnostic )
@@ -500,9 +489,9 @@ KernelsData SDPAKernelMicro::GetKernelsData(const Params& params) const {
     kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 2}); // V
     kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // A

-    if (prim_params.inputs.size() >= 4)
+    if (params.inputs.size() >= 4)
         kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 3}); // mask
-    if (prim_params.inputs.size() >= 5)
+    if (params.inputs.size() >= 5)
         kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 4}); // Scale

     kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // D
@@ -530,23 +519,44 @@ KernelsData SDPAKernelMicro::GetKernelsData(const Params& params) const {
     shim_options.useTileOps = true;
     shim_options.decorator = "kq";

-    kd.kernels[0].code.kernelString->jit += generateShim(gemm_kq, micro::HostLanguage::OpenCL_C, shim_options);
+    kernel.code.kernelString->jit += generateShim(gemm_kq, micro::HostLanguage::OpenCL_C, shim_options);

     shim_options.microkernelID++;
     shim_options.decorator = "vs";
-    kd.kernels[0].code.kernelString->jit += generateShim(gemm_vs, micro::HostLanguage::OpenCL_C, shim_options);
+    kernel.code.kernelString->jit += generateShim(gemm_vs, micro::HostLanguage::OpenCL_C, shim_options);

     if (gemm_kq.grfMin > 128 || gemm_vs.grfMin > 128)
-        kd.kernels[0].code.kernelString->options += " -cl-intel-256-GRF-per-thread";
+        kernel.code.kernelString->options += " -cl-intel-256-GRF-per-thread";

     std::string extra_options = " -Dcl_intel_dot_accumulate";
     extra_options += " -Dcl_intel_global_float_atomic";
     extra_options += " -Dcl_intel_subgroup_matrix_multiply_accumulate";
     extra_options += " -Dcl_intel_subgroup_split_matrix_multiply_accumulate";
-    kd.kernels[0].code.kernelString->options += extra_options;
+    kernel.code.kernelString->options += extra_options;

-    kd.kernels[0].code.kernelString->batch_compilation = false;
-    kd.kernels[0].code.kernelString->has_microkernels = true;
+    kernel.code.kernelString->batch_compilation = false;
+    kernel.code.kernelString->has_microkernels = true;
+
+    return kernel;
+}
+
+KernelsData SDPAKernelMicro::GetKernelsData(const Params& params) const {
+    const size_t num_kernels = 2;
+    KernelData kd = KernelData::Default<sdpa_params>(params, num_kernels);
+    const auto& prim_params = dynamic_cast<const sdpa_params&>(params);
+
+    if (!Validate(params)) {
+        return {};
+    }
+
+    gemms_kq.resize(2);
+    gemms_vs.resize(2);
+
+    for (size_t i = 0; i < num_kernels; i++) {
+        kd.kernels[i] = get_kernel_data(prim_params, gemms_kq[i], gemms_vs[i], i == prefill_id);
+    }
+
+    GetUpdateDispatchDataFunc(kd);

     std::cerr << prim_params.layerID << " use micro_sdpa!\n";
     return { kd };
@@ -555,20 +565,14 @@ KernelsData SDPAKernelMicro::GetKernelsData(const Params& params) const
 void SDPAKernelMicro::GetUpdateDispatchDataFunc(KernelData& kd) const {
     kd.update_dispatch_data_func = [this](const Params& params, KernelData& kernel_data) {
         const auto& prim_params = static_cast<const sdpa_params&>(params);
-        auto dispatchData = SetDefault(prim_params);
-        OPENVINO_ASSERT(kernel_data.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func");
-        kernel_data.kernels[0].params.workGroups.global = dispatchData.gws;
-        kernel_data.kernels[0].params.workGroups.local = dispatchData.lws;
-        kernel_data.kernels[0].skip_execution = KernelData::SkipKernelExecution(prim_params);
-
-        auto head_size = prim_params.conf.head_size;
-
         const auto& Q = prim_params.inputs[0];
         const auto& K = prim_params.inputs[1];

         const auto n_queries = get_seq_length(Q, prim_params.input0_order);
         const auto n_keys = get_seq_length(K, prim_params.input1_order);

+        auto head_size = prim_params.conf.head_size;
+
         ScalarDescriptor s_d;
         s_d.t = ScalarDescriptor::Types::INT32;
         s_d.v.s32 = static_cast<uint32_t>(head_size);
@@ -581,11 +585,24 @@ void SDPAKernelMicro::GetUpdateDispatchDataFunc(KernelData& kd) const {
         s_q.t = ScalarDescriptor::Types::INT32;
         s_q.v.s32 = static_cast<uint32_t>(n_queries.v);

+        const bool is_prefill = true;//n_queries.v > 1;
+
+        OPENVINO_ASSERT(kernel_data.kernels.size() == 2, "[GPU] Invalid kernels size for update dispatch data func");
+
+        size_t target_kernel = is_prefill ? prefill_id : generate_id;
+
+        kernel_data.kernels[prefill_id].skip_execution = true;
+        kernel_data.kernels[generate_id].skip_execution = true;
+
+        auto dispatchData = SetDefault(prim_params, gemms_kq[target_kernel], gemms_vs[target_kernel]);
+        kernel_data.kernels[target_kernel].params.workGroups.global = dispatchData.gws;
+        kernel_data.kernels[target_kernel].params.workGroups.local = dispatchData.lws;
+        kernel_data.kernels[target_kernel].skip_execution = KernelData::SkipKernelExecution(prim_params);

-        kernel_data.kernels[0].params.scalars.clear();
-        kernel_data.kernels[0].params.scalars.push_back(s_d);
-        kernel_data.kernels[0].params.scalars.push_back(s_k);
-        kernel_data.kernels[0].params.scalars.push_back(s_q);
+        kernel_data.kernels[target_kernel].params.scalars.clear();
+        kernel_data.kernels[target_kernel].params.scalars.push_back(s_d);
+        kernel_data.kernels[target_kernel].params.scalars.push_back(s_k);
+        kernel_data.kernels[target_kernel].params.scalars.push_back(s_q);
     };
 }
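The heart of the change is the dispatch-update callback: both compiled variants (prefill and generate) now live in the same KernelData, both start out with skip_execution set, and only the kernel matching the current phase is re-enabled and given work-group sizes. Below is a minimal standalone sketch of that selection pattern, using simplified stand-in structs rather than the real clKernelData/CommonDispatchData types, so it illustrates only the control flow, not the OpenVINO API:

    #include <array>
    #include <cstddef>
    #include <iostream>

    // Simplified stand-ins for the kernel-selector types (illustration only).
    struct DispatchData { size_t gws; size_t lws; };
    struct Kernel       { DispatchData dispatch{}; bool skip_execution = true; };

    constexpr size_t prefill_id  = 0;  // long query sequence (first pass over the prompt)
    constexpr size_t generate_id = 1;  // token-by-token decoding (short query)

    // Mirrors the update_dispatch_data_func logic: disable both kernels,
    // then re-enable only the variant that matches the current phase.
    void update_dispatch(std::array<Kernel, 2>& kernels, size_t n_queries) {
        // The commit currently pins is_prefill to true; the commented-out
        // `n_queries.v > 1` check suggests the intended runtime condition.
        const bool is_prefill = n_queries > 1;
        const size_t target = is_prefill ? prefill_id : generate_id;

        kernels[prefill_id].skip_execution  = true;
        kernels[generate_id].skip_execution = true;

        kernels[target].dispatch = DispatchData{/*gws=*/n_queries * 16, /*lws=*/16};
        kernels[target].skip_execution = false;
    }

    int main() {
        std::array<Kernel, 2> kernels{};
        update_dispatch(kernels, /*n_queries=*/128);  // prefill-shaped request
        std::cout << "prefill enabled:  " << !kernels[prefill_id].skip_execution << "\n";
        std::cout << "generate enabled: " << !kernels[generate_id].skip_execution << "\n";
    }

Only one of the two kernels is dispatched per inference step; the other stays compiled but skipped, which avoids recompilation when the request shape flips between prefill and generation.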

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_micro.h
9 additions, 4 deletions

@@ -21,16 +21,21 @@ class SDPAKernelMicro : public SDPAKernelBase {
 protected:
     bool Validate(const Params& p) const override;
     void GetUpdateDispatchDataFunc(KernelData& kd) const override;
-    CommonDispatchData SetDefault(const sdpa_params& params) const;
-    JitConstants GetJitConstants(const sdpa_params& params) const;
+    CommonDispatchData SetDefault(const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs) const;
+    JitConstants GetJitConstants(const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs) const;
     std::vector<FusedOpType> GetSupportedFusedOps() const override {
         return {};
     }

-    void init_microkernels(const sdpa_params& params) const;
+    void init_microkernels(const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs, bool is_prefill) const;
+    clKernelData get_kernel_data(const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs, bool is_prefill) const;

 private:
-    mutable micro::Package gemm_kq, gemm_vs;
+    mutable std::vector<micro::Package> gemms_kq;
+    mutable std::vector<micro::Package> gemms_vs;
+
+    static constexpr size_t prefill_id = 0;
+    static constexpr size_t generate_id = 1;
     static std::mutex m;
 };
 } // namespace kernel_selector
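With the header changes in place, each variant owns its own KQ/VS microkernel pair (gemms_kq[i], gemms_vs[i]), indexed by prefill_id/generate_id, and GetKernelsData builds both in a loop. The sketch below illustrates that build-time split with simplified types; select_config is a hypothetical helper standing in for the pre-tuned config lookup, and names like base_name are placeholders rather than real kernel_selector identifiers:

    #include <array>
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // Simplified stand-in for micro::Package (illustration only).
    struct Package { std::string config; };
    struct Kernel  { std::string name; };

    constexpr size_t prefill_id  = 0;
    constexpr size_t generate_id = 1;

    // Hypothetical config choice: the generate variant always takes the thin-query
    // path, mirroring `thin_q = (... <= 16) || !is_prefill` from the commit.
    std::string select_config(size_t n_queries, bool is_prefill) {
        const bool thin_q = (n_queries <= 16) || !is_prefill;
        return thin_q ? "thin_q" : "wide_q";
    }

    int main() {
        const std::string base_name = "sdpa_micro";    // placeholder kernel name
        const size_t n_queries = 128;                  // tuning input at build time

        std::vector<Package> gemms_kq(2), gemms_vs(2); // one KQ/VS pair per variant
        std::array<Kernel, 2> kernels;

        for (size_t i = 0; i < kernels.size(); ++i) {
            const bool is_prefill = (i == prefill_id);
            kernels[i].name = base_name + (is_prefill ? "_prefill" : "_generate");
            gemms_kq[i] = Package{"kq_" + select_config(n_queries, is_prefill)};
            gemms_vs[i] = Package{"vs_" + select_config(n_queries, is_prefill)};
        }

        for (size_t i = 0; i < kernels.size(); ++i)
            std::cout << kernels[i].name << " -> " << gemms_kq[i].config << "\n";
    }

Keeping the packages in per-variant vectors lets the dispatch-update callback later pick the matching GEMM tiles without touching the other variant's state.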
