@@ -334,7 +334,7 @@ bool SDPAKernelMicro::Validate(const Params& p) const {
334334 return true ;
335335}
336336
337- JitConstants SDPAKernelMicro::GetJitConstants (const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs) const {
337+ JitConstants SDPAKernelMicro::GetJitConstants (const sdpa_params& params, const micro::Package& gemm_kq, const micro::Package& gemm_vs) const {
338338 auto jit = MakeBaseParamsJitConstants (params);
339339 const auto & prim_params = dynamic_cast <const sdpa_params&>(params);
340340
@@ -452,7 +452,7 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, micro::
452452 return jit;
453453}
454454
455- CommonDispatchData SDPAKernelMicro::SetDefault (const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs) const {
455+ CommonDispatchData SDPAKernelMicro::SetDefault (const sdpa_params& params, const micro::Package& gemm_kq, const micro::Package& gemm_vs) const {
456456 CommonDispatchData dispatch_data;
457457
458458 auto wg_tile_q = gemm_kq.getSetting (" wg_tile_n" );
@@ -468,12 +468,14 @@ CommonDispatchData SDPAKernelMicro::SetDefault(const sdpa_params& params, micro:
468468 return dispatch_data;
469469}
470470
471- clKernelData SDPAKernelMicro::get_kernel_data (const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs, bool is_prefill) const {
471+ clKernelData SDPAKernelMicro::get_kernel_data (const sdpa_params& params, bool is_prefill) const {
472472 auto name = kernelName + (is_prefill ? " _prefill" : " _generate" );
473- init_microkernels (params, gemm_kq, gemm_vs, is_prefill);
474- auto dispatch_data = SetDefault (params, gemm_kq, gemm_vs);
473+
474+ std::vector<micro::Package> gemms (2 ); // KQ and VS
475+ init_microkernels (params, gemms[kq_id], gemms[vs_id], is_prefill);
476+ auto dispatch_data = SetDefault (params, gemms[kq_id], gemms[vs_id]);
475477 auto entry_point = GetEntryPoint (name, params.layerID , params);
476- auto jit = CreateJit (name, GetJitConstants (params, gemm_kq, gemm_vs ), entry_point);
478+ auto jit = CreateJit (name, GetJitConstants (params, gemms[kq_id], gemms[vs_id] ), entry_point);
477479 clKernelData kernel;
478480
479481 FillCLKernelData (kernel, dispatch_data, params.engineInfo , kernelName, jit, entry_point,
@@ -519,13 +521,13 @@ clKernelData SDPAKernelMicro::get_kernel_data(const sdpa_params& params, micro::
519521 shim_options.useTileOps = true ;
520522 shim_options.decorator = " kq" ;
521523
522- kernel.code .kernelString ->jit += generateShim (gemm_kq , micro::HostLanguage::OpenCL_C, shim_options);
524+ kernel.code .kernelString ->jit += generateShim (gemms[kq_id] , micro::HostLanguage::OpenCL_C, shim_options);
523525
524526 shim_options.microkernelID ++;
525527 shim_options.decorator = " vs" ;
526- kernel.code .kernelString ->jit += generateShim (gemm_vs , micro::HostLanguage::OpenCL_C, shim_options);
528+ kernel.code .kernelString ->jit += generateShim (gemms[vs_id] , micro::HostLanguage::OpenCL_C, shim_options);
527529
528- if (gemm_kq .grfMin > 128 || gemm_vs .grfMin > 128 )
530+ if (gemms[kq_id] .grfMin > 128 || gemms[vs_id] .grfMin > 128 )
529531 kernel.code .kernelString ->options += " -cl-intel-256-GRF-per-thread" ;
530532
531533 std::string extra_options = " -Dcl_intel_dot_accumulate" ;
@@ -537,6 +539,10 @@ clKernelData SDPAKernelMicro::get_kernel_data(const sdpa_params& params, micro::
537539 kernel.code .kernelString ->batch_compilation = false ;
538540 kernel.code .kernelString ->has_microkernels = true ;
539541
542+ for (auto & p : gemms) {
543+ kernel.micro_kernels .push_back (std::make_shared<micro::MicroKernelPackage>(p));
544+ }
545+
540546 return kernel;
541547}
542548
@@ -549,11 +555,8 @@ KernelsData SDPAKernelMicro::GetKernelsData(const Params& params) const {
549555 return {};
550556 }
551557
552- gemms_kq.resize (2 );
553- gemms_vs.resize (2 );
554-
555558 for (size_t i = 0 ; i < num_kernels; i++) {
556- kd.kernels [i] = get_kernel_data (prim_params, gemms_kq[i], gemms_vs[i], i == prefill_id);
559+ kd.kernels [i] = get_kernel_data (prim_params, i == prefill_id);
557560 }
558561
559562 GetUpdateDispatchDataFunc (kd);
@@ -594,7 +597,8 @@ void SDPAKernelMicro::GetUpdateDispatchDataFunc(KernelData& kd) const {
594597 kernel_data.kernels [prefill_id].skip_execution = true ;
595598 kernel_data.kernels [generate_id].skip_execution = true ;
596599
597- auto dispatchData = SetDefault (prim_params, gemms_kq[target_kernel], gemms_vs[target_kernel]);
600+ const auto & gemms = kernel_data.kernels [target_kernel].micro_kernels ;
601+ auto dispatchData = SetDefault (prim_params, gemms[kq_id]->p , gemms[vs_id]->p );
598602 kernel_data.kernels [target_kernel].params .workGroups .global = dispatchData.gws ;
599603 kernel_data.kernels [target_kernel].params .workGroups .local = dispatchData.lws ;
600604 kernel_data.kernels [target_kernel].skip_execution = KernelData::SkipKernelExecution (prim_params);
0 commit comments