66#include " common_tools.h"
77#include " common_types.h"
88#include " jitter.h"
9+ #include " kernel_selector_common.h"
910#include " kernel_selector_params.h"
1011#include " micro_utils.hpp"
1112#include " tensor_type.h"
1213
13-
1414#include < algorithm>
1515#include < mutex>
1616#include < string>
@@ -178,7 +178,7 @@ sdpa_config_t *choose_config_xehpc(int head_size, int seq, bool thin_q) {
178178
// Out-of-class definition of the static mutex member. init_microkernels takes a
// std::lock_guard on it as its first statement, so microkernel selection is
// serialized across callers of that function.
179179std::mutex SDPAKernelMicro::m;
180180
181- void SDPAKernelMicro::init_microkernels (const sdpa_params& params) const {
181+ void SDPAKernelMicro::init_microkernels (const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs, bool is_prefill ) const {
182182 std::lock_guard<std::mutex> l (m);
183183 const auto & Q = params.inputs [0 ];
184184 const auto & K = params.inputs [1 ];
@@ -187,14 +187,14 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params) const {
187187 auto & out = params.outputs [0 ];
188188 const auto head_size = params.conf .head_size ;
189189 const auto d_max = get_d_max (head_size);
190- const Tensor::Dim n_keys = K.X ();
190+ const Tensor::Dim n_keys = K.X (). v ; // get_seq_length(K, params.input1_order) ;
191191 const Tensor::Dim n_queries = get_seq_length (Q, params.input0_order );
192192 const Tensor::Dim n_values = V.X ();
193193 const auto batch = out.Batch ().v * out.Feature ().v ;
194194
195195 /* Retrieve pre-tuned kernel configuration */
196196 sdpa_config_t *config = nullptr ;
197- bool thin_q = !n_queries.is_dynamic && (n_queries.v <= 16 );
197+ bool thin_q = ( !n_queries.is_dynamic && (n_queries.v <= 16 )) || !is_prefill ;
198198
199199 switch (params.engineInfo .arch ) {
200200 case gpu_arch::xe_hpg: {
@@ -223,7 +223,7 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params) const {
223223 problem.Ts = problem.Tc ;
224224
225225 auto problem_kq = problem;
226- problem_kq.A .layout = micro::MatrixLayout::T; // TODO: support transpose with MatrixLayout::N layout
226+ problem_kq.A .layout = micro::MatrixLayout::T;
227227 problem_kq.B .layout = micro::MatrixLayout::Pr;
228228 problem_kq.C .layout = micro::MatrixLayout::T;
229229 problem_kq.A .setAlignment (micro::alignment_for_ld (head_size * problem.Ta ));
@@ -252,7 +252,7 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params) const {
252252 opts_kq.slmPtr = true ;
253253
254254 /* Ask microkernel provider for microkernel */
255- gemm_kq = selectGEMMMicrokernel (opts_kq, hw_info, sizes, problem_kq, reqs_kq);
255+ gemm_kq = micro::select_gemm_microkernel (opts_kq, hw_info, sizes, problem_kq, reqs_kq);
256256
257257 /* Update for second GEMM: V*S */
258258 auto problem_vs = problem;
@@ -334,7 +334,7 @@ bool SDPAKernelMicro::Validate(const Params& p) const {
334334 return true ;
335335}
336336
// Builds the JIT constants for the micro SDPA kernel.
// Diff view: the '-' line is the previous signature; the '+' line is the new one,
// which receives the two micro::Package objects (gemm_kq, gemm_vs) by reference
// instead of reading them from the object. How the packages are used is hidden by
// the hunk gap below; only the first two statements and the final return are visible.
337- JitConstants SDPAKernelMicro::GetJitConstants (const sdpa_params& params) const {
337+ JitConstants SDPAKernelMicro::GetJitConstants (const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs ) const {
// Start from the base-parameter constants; the function returns this object.
338338 auto jit = MakeBaseParamsJitConstants (params);
// NOTE(review): `params` is already declared as `const sdpa_params&`, so this
// dynamic_cast converts a type to itself; `prim_params` is just an alias of `params`.
339339 const auto & prim_params = dynamic_cast <const sdpa_params&>(params);
340340
@@ -452,7 +452,7 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params) const {
452452 return jit;
453453}
454454
// Computes and returns the dispatch data (CommonDispatchData) for one kernel.
// Diff view: '-' is the previous signature, '+' the new one taking the GEMM
// packages explicitly. Most of the body is hidden by the hunk gap below, so the
// way gws/lws are derived is not visible here.
455- CommonDispatchData SDPAKernelMicro::SetDefault (const sdpa_params& params) const {
455+ CommonDispatchData SDPAKernelMicro::SetDefault (const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs ) const {
456456 CommonDispatchData dispatch_data;
457457
// Reads the "wg_tile_n" setting of the K*Q package into wg_tile_q.
// NOTE(review): in the visible lines gemm_kq is only read; if getSetting is a
// const member, both package parameters could be `const micro::Package&` — TODO confirm.
458458 auto wg_tile_q = gemm_kq.getSetting (" wg_tile_n" );
@@ -468,28 +468,17 @@ CommonDispatchData SDPAKernelMicro::SetDefault(const sdpa_params& params) const
468468 return dispatch_data;
469469}
470470
// Diff view of the refactoring that turns the body of the former GetKernelsData
// into the helper get_kernel_data. Lines marked '-' are the removed code, lines
// marked '+' are the new helper, unmarked lines are shared context. Two hunk
// headers inside this span hide parts of the body.
//
// New helper contract (from the '+' lines):
//   params      - SDPA parameters used for dispatch, JIT and argument setup.
//   gemm_kq/vs  - packages passed to init_microkernels and then used for the shims.
//   is_prefill  - selects the "_prefill" or "_generate" name suffix and is forwarded
//                 to init_microkernels.
//   returns     - a locally built clKernelData, returned by value.
471- KernelsData SDPAKernelMicro::GetKernelsData (const Params& params) const {
472- KernelData kd = KernelData::Default<sdpa_params>(params);
473- const auto & prim_params = dynamic_cast <const sdpa_params&>(params);
474-
475- if (!Validate (params)) {
476- return {};
477- }
478-
479- init_microkernels (prim_params);
480-
481- auto dispatchData = SetDefault (prim_params);
482- auto entry_point = GetEntryPoint (kernelName, prim_params.layerID , params);
483- auto cldnn_jit = GetJitConstants (prim_params);
484- auto jit = CreateJit (kernelName, cldnn_jit, entry_point);
471+ clKernelData SDPAKernelMicro::get_kernel_data (const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs, bool is_prefill) const {
// Per-variant name: kernelName plus a suffix chosen by is_prefill. It is used for
// the entry point and for CreateJit below.
472+ auto name = kernelName + (is_prefill ? " _prefill" : " _generate" );
// The packages are passed to init_microkernels before SetDefault and
// GetJitConstants receive them.
473+ init_microkernels (params, gemm_kq, gemm_vs, is_prefill);
474+ auto dispatch_data = SetDefault (params, gemm_kq, gemm_vs);
475+ auto entry_point = GetEntryPoint (name, params.layerID , params);
476+ auto jit = CreateJit (name, GetJitConstants (params, gemm_kq, gemm_vs), entry_point);
477+ clKernelData kernel;
485478
486- auto & kernel = kd.kernels [0 ];
487-
488- GetUpdateDispatchDataFunc (kd);
489-
490- FillCLKernelData (kernel, dispatchData, params.engineInfo , kernelName, jit, entry_point,
491- " " , false , false , static_cast <int >(prim_params.inputs .size ()),
492- GetFusedPrimitiveInputsCount (params), 1 , prim_params.is_shape_agnostic );
// NOTE(review): the entry point and JIT above are created with the suffixed `name`,
// but FillCLKernelData is still given the unsuffixed `kernelName`. Presumably this
// argument identifies the shared kernel source/template — TODO confirm it is intended.
479+ FillCLKernelData (kernel, dispatch_data, params.engineInfo , kernelName, jit, entry_point,
480+ " " , false , false , static_cast <int >(params.inputs .size ()),
481+ GetFusedPrimitiveInputsCount (params), 1 , params.is_shape_agnostic );
493482
// The argument list produced by FillCLKernelData is discarded and rebuilt by hand.
494483 kernel.params .arguments .clear ();
495484 if (params.is_shape_agnostic )
@@ -500,9 +489,9 @@ KernelsData SDPAKernelMicro::GetKernelsData(const Params& params) const {
500489 kernel.params .arguments .push_back ({ArgumentDescriptor::Types::INPUT, 2 }); // V
501490 kernel.params .arguments .push_back ({ArgumentDescriptor::Types::OUTPUT, 0 }); // A
502491
// Inputs 3 and 4 are appended only when the primitive has at least 4 / 5 inputs.
503- if (prim_params .inputs .size () >= 4 )
492+ if (params .inputs .size () >= 4 )
504493 kernel.params .arguments .push_back ({ArgumentDescriptor::Types::INPUT, 3 }); // mask
505- if (prim_params .inputs .size () >= 5 )
494+ if (params .inputs .size () >= 5 )
506495 kernel.params .arguments .push_back ({ArgumentDescriptor::Types::INPUT, 4 }); // Scale
507496
508497 kernel.params .arguments .push_back ({ArgumentDescriptor::Types::SCALAR, 0 }); // D
@@ -530,23 +519,44 @@ KernelsData SDPAKernelMicro::GetKernelsData(const Params& params) const {
530519 shim_options.useTileOps = true ;
531520 shim_options.decorator = " kq" ;
532521
// The generated shim for the K*Q package is appended to the kernel's JIT text.
533- kd. kernels [ 0 ] .code .kernelString ->jit += generateShim (gemm_kq, micro::HostLanguage::OpenCL_C, shim_options);
522+ kernel .code .kernelString ->jit += generateShim (gemm_kq, micro::HostLanguage::OpenCL_C, shim_options);
534523
// The same options object is reused for the second shim with the next microkernel ID
// and the "vs" decorator.
535524 shim_options.microkernelID ++;
536525 shim_options.decorator = " vs" ;
537- kd. kernels [ 0 ] .code .kernelString ->jit += generateShim (gemm_vs, micro::HostLanguage::OpenCL_C, shim_options);
526+ kernel .code .kernelString ->jit += generateShim (gemm_vs, micro::HostLanguage::OpenCL_C, shim_options);
538527
// The 256-GRF build option is added when either package reports grfMin above 128.
539528 if (gemm_kq.grfMin > 128 || gemm_vs.grfMin > 128 )
540- kd. kernels [ 0 ] .code .kernelString ->options += " -cl-intel-256-GRF-per-thread" ;
529+ kernel .code .kernelString ->options += " -cl-intel-256-GRF-per-thread" ;
541530
// Extension-enabling defines are accumulated here and appended to the build options.
542531 std::string extra_options = " -Dcl_intel_dot_accumulate" ;
543532 extra_options += " -Dcl_intel_global_float_atomic" ;
544533 extra_options += " -Dcl_intel_subgroup_matrix_multiply_accumulate" ;
545534 extra_options += " -Dcl_intel_subgroup_split_matrix_multiply_accumulate" ;
546- kd. kernels [ 0 ] .code .kernelString ->options += extra_options;
535+ kernel .code .kernelString ->options += extra_options;
547536
548- kd.kernels [0 ].code .kernelString ->batch_compilation = false ;
549- kd.kernels [0 ].code .kernelString ->has_microkernels = true ;
// Flags set on the kernel string: batch compilation off, microkernels present.
537+ kernel.code .kernelString ->batch_compilation = false ;
538+ kernel.code .kernelString ->has_microkernels = true ;
539+
540+ return kernel;
541+ }
542+
// New GetKernelsData: builds a KernelData holding two kernels, one per variant,
// by calling get_kernel_data for each index; the index equal to prefill_id is built
// with is_prefill == true. Returns an empty KernelsData when Validate fails.
// The closing brace of this function falls in the hunk gap that follows.
543+ KernelsData SDPAKernelMicro::GetKernelsData (const Params& params) const {
544+ const size_t num_kernels = 2 ;
// NOTE(review): KernelData::Default and the dynamic_cast below both run before
// Validate(params). A dynamic_cast to a reference type throws std::bad_cast on a
// type mismatch, so a non-sdpa Params would throw here instead of returning {}.
// The previous version had the same ordering; whether Validate checks the params
// type is not visible in this view.
545+ KernelData kd = KernelData::Default<sdpa_params>(params, num_kernels);
546+ const auto & prim_params = dynamic_cast <const sdpa_params&>(params);
547+
548+ if (!Validate (params)) {
549+ return {};
550+ }
551+
// NOTE(review): the literal 2 duplicates num_kernels; the loop below indexes these
// containers with i < num_kernels, so the two values must stay in sync.
// NOTE(review): gemms_kq / gemms_vs are not declared in this function and are
// resized inside a const method — presumably mutable members; their declaration
// is not visible here. No lock is taken in this function around these writes.
552+ gemms_kq.resize (2 );
553+ gemms_vs.resize (2 );
554+
555+ for (size_t i = 0 ; i < num_kernels; i++) {
556+ kd.kernels [i] = get_kernel_data (prim_params, gemms_kq[i], gemms_vs[i], i == prefill_id);
557+ }
558+
// Installs the update_dispatch_data_func callback on kd.
559+ GetUpdateDispatchDataFunc (kd);
550560
// NOTE(review): unconditional write to std::cerr on every successful call; this
// looks like a leftover debug trace and was already present before this change.
551561 std::cerr << prim_params.layerID << " use micro_sdpa!\n " ;
552562 return { kd };
@@ -555,20 +565,14 @@ KernelsData SDPAKernelMicro::GetKernelsData(const Params& params) const {
// Installs kd.update_dispatch_data_func: a callback that refreshes work-group
// sizes, skip flags and scalar arguments from the current params.
// Diff view: '-' lines are the old single-kernel logic, '+' lines the new
// two-kernel logic. A hunk gap in the middle hides the setup of s_k and the
// declaration of s_q.
555565void SDPAKernelMicro::GetUpdateDispatchDataFunc (KernelData& kd) const {
// NOTE(review): the lambda captures `this` and, in the '+' lines below, reads the
// members gemms_kq / gemms_vs when it is invoked. GetKernelsData writes those
// members on each call, so the callback sees whatever was stored most recently —
// verify this against how instances of this class are shared between primitives.
556566 kd.update_dispatch_data_func = [this ](const Params& params, KernelData& kernel_data) {
557567 const auto & prim_params = static_cast <const sdpa_params&>(params);
558- auto dispatchData = SetDefault (prim_params);
559- OPENVINO_ASSERT (kernel_data.kernels .size () == 1 , " [GPU] Invalid kernels size for update dispatch data func" );
560- kernel_data.kernels [0 ].params .workGroups .global = dispatchData.gws ;
561- kernel_data.kernels [0 ].params .workGroups .local = dispatchData.lws ;
562- kernel_data.kernels [0 ].skip_execution = KernelData::SkipKernelExecution (prim_params);
563-
564- auto head_size = prim_params.conf .head_size ;
565-
566568 const auto & Q = prim_params.inputs [0 ];
567569 const auto & K = prim_params.inputs [1 ];
568570
// Sequence lengths are taken from inputs 0 and 1 using their respective input orders.
569571 const auto n_queries = get_seq_length (Q, prim_params.input0_order );
570572 const auto n_keys = get_seq_length (K, prim_params.input1_order );
571573
574+ auto head_size = prim_params.conf .head_size ;
575+
572576 ScalarDescriptor s_d;
573577 s_d.t = ScalarDescriptor::Types::INT32;
// NOTE(review): the value is cast to uint32_t but stored in the `s32` field of a
// descriptor tagged INT32; the same pattern is used for s_q below.
574578 s_d.v .s32 = static_cast <uint32_t >(head_size);
@@ -581,11 +585,24 @@ void SDPAKernelMicro::GetUpdateDispatchDataFunc(KernelData& kd) const {
581585 s_q.t = ScalarDescriptor::Types::INT32;
582586 s_q.v .s32 = static_cast <uint32_t >(n_queries.v );
583587
// NOTE(review): is_prefill is hard-coded to true (the n_queries-based condition is
// commented out), so target_kernel below is always prefill_id and the generate
// kernel always keeps skip_execution == true.
588+ const bool is_prefill = true ;// n_queries.v > 1;
589+
590+ OPENVINO_ASSERT (kernel_data.kernels .size () == 2 , " [GPU] Invalid kernels size for update dispatch data func" );
591+
592+ size_t target_kernel = is_prefill ? prefill_id : generate_id;
593+
// Both kernels are first marked as skipped; only the target kernel's flag is then
// overwritten with the result of SkipKernelExecution a few lines below.
594+ kernel_data.kernels [prefill_id].skip_execution = true ;
595+ kernel_data.kernels [generate_id].skip_execution = true ;
596+
// Dispatch data is recomputed with the packages stored for the selected variant.
597+ auto dispatchData = SetDefault (prim_params, gemms_kq[target_kernel], gemms_vs[target_kernel]);
598+ kernel_data.kernels [target_kernel].params .workGroups .global = dispatchData.gws ;
599+ kernel_data.kernels [target_kernel].params .workGroups .local = dispatchData.lws ;
600+ kernel_data.kernels [target_kernel].skip_execution = KernelData::SkipKernelExecution (prim_params);
584601
585- kernel_data.kernels [0 ].params .scalars .clear ();
586- kernel_data.kernels [0 ].params .scalars .push_back (s_d);
587- kernel_data.kernels [0 ].params .scalars .push_back (s_k);
588- kernel_data.kernels [0 ].params .scalars .push_back (s_q);
// Scalars are rebuilt only for the target kernel, pushed in the order s_d, s_k, s_q;
// the other kernel's scalars are left untouched.
602+ kernel_data.kernels [target_kernel ].params .scalars .clear ();
603+ kernel_data.kernels [target_kernel ].params .scalars .push_back (s_d);
604+ kernel_data.kernels [target_kernel ].params .scalars .push_back (s_k);
605+ kernel_data.kernels [target_kernel ].params .scalars .push_back (s_q);
589606 };
590607}
591608
0 commit comments