From 985f87a1fec7c4493853bf3ee00140ee8853f60b Mon Sep 17 00:00:00 2001 From: Yuankun Shi Date: Fri, 9 May 2025 09:59:20 -0700 Subject: [PATCH 1/6] first splitk test --- .../05_pvc_gemm_with_epilogue_splitk.cpp | 458 ++++++++++++++++++ .../cutlass/epilogue/fusion/operations.hpp | 13 + .../cutlass/epilogue/fusion/xe_callbacks.hpp | 74 +++ .../epilogue/fusion/xe_visitor_splitk.hpp | 233 +++++++++ 4 files changed, 778 insertions(+) create mode 100644 examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp create mode 100644 include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp new file mode 100644 index 0000000000..2164bcd7e9 --- /dev/null +++ b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp @@ -0,0 +1,458 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_relu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline +bool isclose(T a, T b, float atol, float rtol) { + return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); +} +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 512); + cmd.get_cmd_line_argument("n", n, 24576); + cmd.get_cmd_line_argument("k", k, 1536); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "PVC GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + static constexpr int NUM_HEAD = 128; + static constexpr int NOPE_DIM = 128; + static constexpr int ROPE_DIM = 64; + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_D1; + cutlass::DeviceAllocation block_D2; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + syclcompat::wait(); + // 256x128x64 + ElementOutput *ptr1 = + (ElementOutput *)std::malloc(M * NUM_HEAD * NOPE_DIM * L * sizeof(ElementOutput)); + // 256x128x128 + ElementOutput *ptr2 = + (ElementOutput *)std::malloc(M * NUM_HEAD * ROPE_DIM * L * sizeof(ElementOutput)); + ElementOutput *ptr= + (ElementOutput *)std::malloc(M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr, block_ref_D.get(), + M * N * L * sizeof(ElementOutput)); + syclcompat::wait(); + + for (int l = 0; l < L; l++) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < NUM_HEAD; j++) { + for (int k = 0; k < NOPE_DIM + ROPE_DIM; ++k) { + if (k < NOPE_DIM) { + ptr1[l * M * NUM_HEAD * NOPE_DIM + i * NUM_HEAD * NOPE_DIM + j * NOPE_DIM + k] = ptr[l * M * N + i * N + j * (NOPE_DIM + ROPE_DIM) + k]; + } else { + ptr2[l * M * NUM_HEAD * ROPE_DIM + i * NUM_HEAD * ROPE_DIM + j * ROPE_DIM + k - NOPE_DIM] = ptr[l * M * N + i * N + j * (NOPE_DIM + ROPE_DIM) + k]; + } + } + } + } + } + + syclcompat::memcpy(block_ref_D.get(), ptr, + M * N * L * sizeof(ElementOutput)); + syclcompat::wait(); + + ElementOutput *ptr_refD = + (ElementOutput *)std::malloc((size_t)M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr_refD, block_D.get(), + (size_t)M * N * L * sizeof(ElementOutput)); + syclcompat::wait(); + + uint32_t err_cnt = 0; + float atol = 1e-4; + float rtol = 1e-4; + for (int b = 0; b < L; b++) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + int idx = b * M * N + i * N + j; + auto expect = ptr[idx]; + auto val = ptr_refD[idx]; + + if (not (std::isinf(val) || std::isnan(val))) { + if (not isclose(val, expect, atol, rtol)) { + std::cout << "(" << b << ", " << i << ", " << j + << "): " << "host: " << ptr[idx] + << " and device: " << ptr_refD[idx] << std::endl; + err_cnt++; + } + } else { + std::cout << "(" << b << ", " << i << ", " << j + << "): " << "host: " << expect << " and device: " << val + << std::endl; + err_cnt++; + } + } + } + } + + std::free(ptr_refD); + std::free(ptr); + std::cout << "err count: " << err_cnt + << ", pass rate: " << 100 - (100 * err_cnt / (M * N * L)) << "%" + << std::endl; + return err_cnt == 0; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_D1.reset(M * NUM_HEAD * NOPE_DIM * L); + block_D2.reset(M * NUM_HEAD * ROPE_DIM * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + using EpilogueArguments = typename Gemm::GemmKernel::EpilogueArguments; + EpilogueArguments epilogue_arguments{ + {options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}; + epilogue_arguments.thread.output_ptr = block_D.get(); + epilogue_arguments.thread.output_ptr1 = block_D1.get(); + epilogue_arguments.thread.output_ptr2 = block_D2.get(); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + epilogue_arguments, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + gemm_op.can_implement(arguments); + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + if (!passed) return cutlass::Status::kErrorInternal; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + double io = + options.l * + (options.m * options.k * sizeof(ElementA) + options.k * options.n * sizeof(ElementB) + + options.m * options.n * sizeof(ElementOutput)) * + 1e-9; + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]GB/s, [%4.3f]TF/s, [%6.4f]ms\n", io / cute_time, tflops/cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x8x16_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_32, _512, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_16, _1, _0>>>::TiledMMA; + + using EpilogueTile = Shape<_16, _32>; + constexpr int PipelineStages = 3; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using EpilogueOp = cutlass::epilogue::fusion::LinCombSplitK; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + void, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index f37eb5b00a..9f23609682 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -146,6 +146,19 @@ struct LinCombTopKSoftmaxCol : LinearCombination { }; +// D = softmax(alpha * acc + beta * C) +template< + class ElementOutput_, + class ElementCompute_, + class CopyOpR2G_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombSplitK + : LinearCombination { +}; + // D = softmax(alpha * acc + beta * C) template< class ElementOutput_, diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 0d199fd383..da1beeec3e 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -48,6 +48,7 @@ #include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" +#include "cutlass/epilogue/fusion/xe_visitor_splitk.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -169,6 +170,79 @@ struct FusionCallbacks< using Impl::Impl; }; +// D = splitk(alpha * acc + beta * C) +template< + // int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class CopyOpR2G, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using XeLinCombSplitK = + Sm90EVT, // splitk(beta * C + (alpha * acc)) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + // int FragmentSize, + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + class CopyOpR2G_, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::IntelPVCEpilogue, + fusion::LinCombSplitK, + CtaTileShapeMNK, + EpilogueTile +> : XeLinCombSplitK { + + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Impl = XeLinCombSplitK::type, ElementCompute, CopyOpR2G_, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombSplitK; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementOutput* output_ptr = nullptr; + ElementOutput *output_ptr1 = nullptr; + ElementOutput *output_ptr2 = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {output_ptr, output_ptr1, output_ptr2} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + // D = softmax(alpha * acc + beta * C) template< // int FragmentSize, diff --git a/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp b/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp new file mode 100644 index 0000000000..04bc9fac9a --- /dev/null +++ b/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp @@ -0,0 +1,233 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree Softmax fusion operation for the Intel PVC epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class CopyOpR2G, + FloatRoundStyle RoundStyle +> +struct XeSplitK +{ +public: + static constexpr int FragmentSize = 8; + static constexpr auto Tile_M = get<0>(CtaTileShapeMNK{}); + static constexpr auto Tile_N = get<1>(CtaTileShapeMNK{}); + static constexpr auto Epi_M = get<0>(EpilogueTile{}); + static constexpr auto Epi_N = get<1>(EpilogueTile{}); + static constexpr auto Sg_M = Tile_M / Epi_M; + static constexpr auto Sg_N = Tile_N / Epi_N; + static constexpr auto Sg_Nums = Sg_M * Sg_N; + + using Trait_Output = Copy_Traits; + using XE_Copy_output = decltype(make_tiled_copy(Copy_Atom{} + .with(static_cast(nullptr),int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename Trait_Output::BlockShape{}), + get<1>(typename Trait_Output::BlockShape{}) / Int{})))); + + struct SharedStorage { }; + + struct Arguments { + ElementOutput* ptr_output; + ElementOutput* ptr_output1; + ElementOutput* ptr_output2; + // StrideOutput dOutput; + }; + + struct Params { + XE_Copy_output xe_store_output; + ElementOutput *ptr_output1; + ElementOutput *ptr_output2; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + XE_Copy_output output = make_tiled_copy(Copy_Atom, ElementOutput>{}.with( + args.ptr_output, M, N), + Layout>>{}, + make_layout(make_shape(get<0>(typename XE_Copy_output::BlockShape{}), + get<1>(typename XE_Copy_output::BlockShape{}) / Int{}))); + + return {output, args.ptr_output1, args.ptr_output2}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + // Cross CTA reduction is not possible because there is no guarantee that all CTAs run + // concurrently. + // Cross epilogue tile reduction is possible, but re-visiting and applying reduction + // to accumulators is only possible for the current epilogue tile. + auto [epi_M, epi_N] = EpilogueTile{}; + return N <= tile_N; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + XeSplitK() { } + + CUTLASS_HOST_DEVICE + XeSplitK(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& res_tensor, CoordTensor&& coord, Params const& params) + : res_tensor(cute::forward(res_tensor)), + coord(cute::forward(coord)), + params(params) {} + + RTensor res_tensor; + CoordTensor coord; + Params const& params; + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + return frg_acc; + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + if(is_last_iteration) { + for(int epi_v = 0; epi_v < visit_results(0).size(); epi_v++) { + res_tensor(epi_v, epi_m, epi_n) = visit_results(0)[epi_v]; + } + + constexpr auto vec_size = min(Epi_M, Sg_N); + constexpr auto vec_folds = Epi_M / vec_size; + + auto smem = syclcompat::local_mem(); + Tensor stensor = make_tensor(make_smem_ptr(smem), make_shape(Int{}, Int{}, Int{})); + + Tensor res = + make_tensor(static_cast(res_tensor).data(), + make_shape(Int{}, Int{}, Int{})); + + + copy(params.xe_store_output, res_tensor, coord); + } + else { + for(int epi_v = 0; epi_v < visit_results(0).size(); epi_v++) { + res_tensor(epi_v, epi_m, epi_n) = visit_results(0)[epi_v]; + } + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + using MmaAtomShape = typename decltype(args.tiled_mma)::AtomShape_MNK; + static constexpr int FragsM = get<0>(EpilogueTile{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(EpilogueTile{}) / get<1>(MmaAtomShape()); // B frags per sub_group + Tensor res = make_tensor(Shape, Int, Int>{}); + + auto [sg_m_coord, sg_n_coord, k_coord, l_offset] = args.tile_coord_mnkl; + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mAux_mnl = cute::get_pvc_tensor(make_shape(M,N,L)); + // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel + Tensor gAux = local_tile(mAux_mnl, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord,l_offset)); + Tensor tCgAux = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux); + + return ConsumerStoreCallbacks( + cute::move(res), + cute::move(tCgAux), + params); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// From c5fcabd04c13d4a9a8123525f088b8f288890c0e Mon Sep 17 00:00:00 2001 From: Yuankun Shi Date: Fri, 9 May 2025 13:47:08 -0700 Subject: [PATCH 2/6] fix cmake --- examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt | 5 +++++ include/cutlass/epilogue/fusion/operations.hpp | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt b/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt index 7ec1f66411..220233601d 100644 --- a/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt +++ b/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt @@ -54,6 +54,11 @@ cutlass_example_add_executable( 05_pvc_gemm_with_epilogue_softmax.cpp ) +cutlass_example_add_executable( + 05_pvc_gemm_with_epilogue_splitk + 05_pvc_gemm_with_epilogue_splitk.cpp +) + cutlass_example_add_executable( 05_pvc_gemm_with_per_row_bias 05_pvc_gemm_with_per_row_bias.cpp diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 9f23609682..5cef2f59c0 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -146,7 +146,7 @@ struct LinCombTopKSoftmaxCol : LinearCombination { }; -// D = softmax(alpha * acc + beta * C) +// D = splitk(alpha * acc + beta * C) template< class ElementOutput_, class ElementCompute_, From 1f76456dbb9542100bf1473cc2c1e420d58ee2e2 Mon Sep 17 00:00:00 2001 From: Yuankun Shi Date: Tue, 13 May 2025 11:28:21 -0700 Subject: [PATCH 3/6] update splitk fusion --- .../05_pvc_gemm_with_epilogue_splitk.cpp | 178 ++++++++++++++---- .../epilogue/fusion/xe_visitor_splitk.hpp | 157 ++++++++++++++- 2 files changed, 291 insertions(+), 44 deletions(-) diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp index 2164bcd7e9..52abf785b4 100644 --- a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp +++ b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp @@ -25,7 +25,7 @@ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -53,13 +53,39 @@ #include "helper.h" using namespace cute; - /////////////////////////////////////////////////////////////////////////////////////////////////// template inline bool isclose(T a, T b, float atol, float rtol) { return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); } + +template +inline +void verify_acc(T expect, T val, const float atol, const float rtol, int &err_cnt) { +} + +template +inline +void random_fill(T *src, int seed, size_t N, float max, float min) { + if constexpr(std::is_same_v || std::is_same_v || std::is_same_v) { + std::random_device rd; + std::mt19937 gen(seed); + std::uniform_real_distribution dis(min, max); + T *buff = + (T *)std::malloc(N * sizeof(T)); + + for (size_t i = 0; i < N; ++i) { + buff[i] = (T)(dis(gen)); + } + syclcompat::memcpy(src, buff, N * sizeof(T)); + syclcompat::wait(); + std::free(buff); + } else { + assert(0 & "Not supported dtype"); + } +} + // Command line options parsing struct Options { @@ -91,7 +117,7 @@ struct Options { cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); - cmd.get_cmd_line_argument("iterations", iterations, 100); + cmd.get_cmd_line_argument("iterations", iterations, 0); } /// Prints the usage statement. @@ -195,70 +221,150 @@ struct ExampleRunner { syclcompat::wait(); // 256x128x64 - ElementOutput *ptr1 = + ElementOutput *ptr_ref_D1 = (ElementOutput *)std::malloc(M * NUM_HEAD * NOPE_DIM * L * sizeof(ElementOutput)); // 256x128x128 - ElementOutput *ptr2 = + ElementOutput *ptr_ref_D2 = (ElementOutput *)std::malloc(M * NUM_HEAD * ROPE_DIM * L * sizeof(ElementOutput)); - ElementOutput *ptr= + ElementOutput *ptr_ref_D = (ElementOutput *)std::malloc(M * N * L * sizeof(ElementOutput)); - syclcompat::memcpy(ptr, block_ref_D.get(), + syclcompat::memcpy(ptr_ref_D, block_ref_D.get(), M * N * L * sizeof(ElementOutput)); syclcompat::wait(); - + printf("res:"); + for (int l = 0; l < L; ++l) { + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + if ((n % 16) == 0) + printf("\n(%03d:%04d): ", m, n); + printf("% 7.3f ", ptr_ref_D[l * M * N + m * N + n]); + } + } + } + printf("\n"); for (int l = 0; l < L; l++) { for (int i = 0; i < M; i++) { for (int j = 0; j < NUM_HEAD; j++) { for (int k = 0; k < NOPE_DIM + ROPE_DIM; ++k) { if (k < NOPE_DIM) { - ptr1[l * M * NUM_HEAD * NOPE_DIM + i * NUM_HEAD * NOPE_DIM + j * NOPE_DIM + k] = ptr[l * M * N + i * N + j * (NOPE_DIM + ROPE_DIM) + k]; + ptr_ref_D1[l * M * NUM_HEAD * NOPE_DIM + i * NUM_HEAD * NOPE_DIM + j * NOPE_DIM + k] = + ptr_ref_D[l * M * N + i * N + j * (NOPE_DIM + ROPE_DIM) + k]; } else { - ptr2[l * M * NUM_HEAD * ROPE_DIM + i * NUM_HEAD * ROPE_DIM + j * ROPE_DIM + k - NOPE_DIM] = ptr[l * M * N + i * N + j * (NOPE_DIM + ROPE_DIM) + k]; + ptr_ref_D2[l * M * NUM_HEAD * ROPE_DIM + i * NUM_HEAD * ROPE_DIM + j * ROPE_DIM + k - NOPE_DIM] = + ptr_ref_D[l * M * N + i * N + j * (NOPE_DIM + ROPE_DIM) + k]; } } } } } - syclcompat::memcpy(block_ref_D.get(), ptr, - M * N * L * sizeof(ElementOutput)); - syclcompat::wait(); - - ElementOutput *ptr_refD = + ElementOutput *ptr_test_D = (ElementOutput *)std::malloc((size_t)M * N * L * sizeof(ElementOutput)); - syclcompat::memcpy(ptr_refD, block_D.get(), + syclcompat::memcpy(ptr_test_D, block_D.get(), (size_t)M * N * L * sizeof(ElementOutput)); + + // 256x128x64 + ElementOutput *ptr_test_D1 = + (ElementOutput *)std::malloc(M * NUM_HEAD * NOPE_DIM * L * sizeof(ElementOutput)); + // 256x128x128 + ElementOutput *ptr_test_D2 = + (ElementOutput *)std::malloc(M * NUM_HEAD * ROPE_DIM * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr_test_D1, block_D1.get(), + M * NUM_HEAD * NOPE_DIM * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr_test_D2, block_D2.get(), + M * NUM_HEAD * ROPE_DIM * L * sizeof(ElementOutput)); + syclcompat::wait(); + + syclcompat::wait(); uint32_t err_cnt = 0; - float atol = 1e-4; - float rtol = 1e-4; + constexpr float atol = 1e-4; + constexpr float rtol = 1e-4; for (int b = 0; b < L; b++) { for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { int idx = b * M * N + i * N + j; - auto expect = ptr[idx]; - auto val = ptr_refD[idx]; - + auto expect = ptr_ref_D[idx]; + auto val = ptr_test_D[idx]; if (not (std::isinf(val) || std::isnan(val))) { - if (not isclose(val, expect, atol, rtol)) { - std::cout << "(" << b << ", " << i << ", " << j - << "): " << "host: " << ptr[idx] - << " and device: " << ptr_refD[idx] << std::endl; - err_cnt++; - } + if (not isclose(val, expect, atol, rtol)) { + std::cout << "(" << b << ", " << i << ", " << j + << "): " << "host: " << expect + << " and device: " << val << std::endl; + err_cnt++; + } } else { - std::cout << "(" << b << ", " << i << ", " << j - << "): " << "host: " << expect << " and device: " << val - << std::endl; - err_cnt++; + std::cout << "(" << b << ", " << i << ", " << j + << "): " << "host: " << expect << " and device: " << val + << std::endl; + err_cnt++; } } } } + constexpr int NUM_HEAD = 8; + constexpr int NOPE_DIM = 128; + constexpr int ROPE_DIM = 64; + printf("CHECK d1:\n"); + // check d1 + for (int b = 0; b < L; b++) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < NUM_HEAD; j++) { + for (int k = 0; k < NOPE_DIM; ++k) { + int idx = b * M * NUM_HEAD * NOPE_DIM + i * NUM_HEAD * NOPE_DIM + j * NOPE_DIM + k; + auto expect = ptr_ref_D1[idx]; + auto val = ptr_test_D1[idx]; + if (not (std::isinf(val) || std::isnan(val))) { + if (not isclose(val, expect, atol, rtol)) { + std::cout << "(" << b << ", " << i << ", " << j << ", " << k + << "): " << "host: " << expect + << " and device: " << val << std::endl; + err_cnt++; + } + } else { + std::cout << "(" << b << ", " << i << ", " << j + << "): " << "host: " << expect << " and device: " << val + << std::endl; + err_cnt++; + } + } + } + } + } + printf("CHECK d2:\n"); + // check d2 + for (int b = 0; b < L; b++) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < NUM_HEAD; j++) { + for (int k = 0; k < ROPE_DIM; ++k) { + int idx = b * M * NUM_HEAD * ROPE_DIM + i * NUM_HEAD * ROPE_DIM + j * ROPE_DIM + k; + auto expect = ptr_ref_D2[idx]; + auto val = ptr_test_D2[idx]; + if (not (std::isinf(val) || std::isnan(val))) { + if (not isclose(val, expect, atol, rtol)) { + std::cout << "(" << b << ", " << i << ", " << j << ", " << k + << "): " << "host: " << expect + << " and device: " << val << std::endl; + err_cnt++; + } + } else { + std::cout << "(" << b << ", " << i << ", " << j + << "): " << "host: " << expect << " and device: " << val + << std::endl; + err_cnt++; + } + } + } + } + } - std::free(ptr_refD); - std::free(ptr); + std::free(ptr_test_D); + std::free(ptr_test_D1); + std::free(ptr_test_D2); + std::free(ptr_ref_D); + std::free(ptr_ref_D1); + std::free(ptr_ref_D2); std::cout << "err count: " << err_cnt << ", pass rate: " << 100 - (100 * err_cnt / (M * N * L)) << "%" << std::endl; @@ -283,9 +389,9 @@ struct ExampleRunner { block_D2.reset(M * NUM_HEAD * ROPE_DIM * L); block_ref_D.reset(M * N * L); - initialize_block(block_A, seed + 2023); - initialize_block(block_B, seed + 2022); - initialize_block(block_C, seed + 2021); + random_fill(block_A.get(), seed + 2023, block_A.size(), -1.0, 1.0); + random_fill(block_B.get(), seed + 2022, block_B.size(), -1.0, 1.0); + random_fill(block_C.get(), seed + 2021, block_C.size(), -1.0, 1.0); } cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { diff --git a/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp b/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp index 04bc9fac9a..0902fb9b24 100644 --- a/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp +++ b/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp @@ -83,6 +83,8 @@ struct XeSplitK XE_Copy_output xe_store_output; ElementOutput *ptr_output1; ElementOutput *ptr_output2; + XE_Copy_output xe_store_output1; + // XE_Copy_output xe_store_output2; }; template @@ -95,8 +97,22 @@ struct XeSplitK Layout>>{}, make_layout(make_shape(get<0>(typename XE_Copy_output::BlockShape{}), get<1>(typename XE_Copy_output::BlockShape{}) / Int{}))); - - return {output, args.ptr_output1, args.ptr_output2}; + constexpr int NUM_HEAD = 8; + constexpr int NOPE_DIM = 128; + constexpr int ROPE_DIM = 64; + + XE_Copy_output output1 = make_tiled_copy( + Copy_Atom, ElementOutput>{}.with( + args.ptr_output1, M, NUM_HEAD *NOPE_DIM), + Layout>>{}, + make_layout(make_shape(get<0>(typename XE_Copy_output::BlockShape{}), + get<1>(typename XE_Copy_output::BlockShape{}) / Int{}))); + + // Tensor mAux_mnl1 = cute::get_pvc_tensor(make_shape(M, 8 * 128, L)); + // Tensor gAux1 = local_tile(mAux_mnl1, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord,l_offset)); + // Tensor tCgAux1 = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux1); + + return {output, args.ptr_output1, args.ptr_output2, output1}; } template @@ -149,18 +165,58 @@ struct XeSplitK get_producer_load_callbacks(ProducerLoadArgs const& args) { return EmptyProducerLoadCallbacks{}; } - + template + CUTLASS_DEVICE static void + print_tensor_orig(VTensor &t) { + print(t); + auto t_shape = t.shape(); + auto t_stride = t.stride(); + auto t_rank = rank(t_shape); + auto total = t.size(); + for (auto i = 0; i < total; ++i) { + if ((i % get<0>(t_shape)) == 0) + print("\n%2d: ", i / get<0>(t_shape)); + print("%7.3f ", t[i]); + } + print("\n"); + } + template + CUTLASS_DEVICE static void + print_tensor(VTensor &t) { + print(t); + for (int k = 0; k < 2; ++k) { + for (int j = 0; j < 2; ++j) { + printf("\n%2d: ", (k * 2 + j) * 8); + for (int i = 0; i < 8; ++i) { + printf("%7.3f ", t(i, j, k)); + } + } + } + printf("\n"); + } + template + CUTLASS_DEVICE static void + print_coord(CoordTensor &t) { + for (int i = 0; i < t.kRank; ++i) { + printf("%d ", t.idx[i]); + } + printf("\n"); + } template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& res_tensor, CoordTensor&& coord, Params const& params) + ConsumerStoreCallbacks(RTensor&& res_tensor, CoordTensor&& coord, CoordTensor&& coord1, CoordTensor&& coord2, Params const& params) : res_tensor(cute::forward(res_tensor)), coord(cute::forward(coord)), + coord1(cute::forward(coord1)), + coord2(cute::forward(coord2)), params(params) {} RTensor res_tensor; CoordTensor coord; + CoordTensor coord1; + CoordTensor coord2; Params const& params; template CUTLASS_DEVICE auto @@ -177,7 +233,6 @@ struct XeSplitK for(int epi_v = 0; epi_v < visit_results(0).size(); epi_v++) { res_tensor(epi_v, epi_m, epi_n) = visit_results(0)[epi_v]; } - constexpr auto vec_size = min(Epi_M, Sg_N); constexpr auto vec_folds = Epi_M / vec_size; @@ -187,8 +242,78 @@ struct XeSplitK Tensor res = make_tensor(static_cast(res_tensor).data(), make_shape(Int{}, Int{}, Int{})); - - + constexpr int tid = 96; + constexpr int bid = 0; + auto m_coord = get<0>(coord[0]); + auto n_coord = get<1>(coord[0]); + auto l_coord = get<2>(coord[0]); + if (cute::thread(tid, bid)) { + printf("res: "); + print(res_tensor); + printf("\n"); + print_tensor_orig(res_tensor); + print_tensor(res_tensor); + printf("\nxe_store:\n"); + print(params.xe_store_output); + printf("\ncoord:\n"); + print(coord); + printf("m %d n %d l %d\n", m_coord, n_coord, l_coord); + } + constexpr int NUM_HEAD = 8; + constexpr int NOPE_DIM = 128; + constexpr int ROPE_DIM = 64; + constexpr int ROW_DIM = NOPE_DIM + ROPE_DIM; + constexpr int ROW_1_DIM = NUM_HEAD * NOPE_DIM; + constexpr int ROW_2_DIM = NUM_HEAD * ROPE_DIM; + constexpr auto thr_stride = 8 * 2; // vec_size*vec_folds + int tidx = ThreadIdxX(); + int bidx = BlockIdxX(); + int col = n_coord; + int idx_2 = col % ROW_DIM; + int idx_1 = col / ROW_DIM; + int idx_0_tid = m_coord; + if (cute::thread(tid, bid)) { + printf("tid %d bid %d m %d n %d blockdim %d tidx 0st %d 1st %d 2nd %d\n", tidx, bidx, m_coord, n_coord, BlockDimX(), idx_0_tid, idx_1, idx_2); + } + // if ((tidx >= tid) && (tidx < tid + 16) && (bidx == bid)) { + if (idx_2 < NOPE_DIM) { + // static_assert(cute::is_same_v); + unsigned long n_coord1 = idx_1 * NOPE_DIM + idx_2; + auto coord1 = make_coord(m_coord, n_coord1, l_coord); + coord[0] = coord1; + if (cute::thread(tid, bid)) { + printf("m %d n %d l %d\n", m_coord, n_coord1, l_coord); + print(coord1); + printf("\n"); + print(coord); + printf("\n"); + } + copy(params.xe_store_output1, res_tensor, coord); + // copy to first tensor + // for (int j = 0; j < 2; ++j) { // vec_folds + // for (int i = 0; i < 8; ++i) { // vec_size + // int idx_0 = j * 8 + i + idx_0_tid; + // for (int k = 0; k < 2; ++k) { // Epi_N / IntelPVCEpilogue::SubgroupSize + // params.ptr_output1[idx_0 * ROW_1_DIM + idx_1 * NOPE_DIM + (k * thr_stride + idx_2)] = res_tensor(i, j, k); + // // ---idx_0-----------------idx_1-------------------idx_2------ + // } + // } + // } + } else { + int rope_idx_2 = idx_2 - NOPE_DIM; + // copy to second tensor + for (int j = 0; j < 2; ++j) { // vec_folds + for (int i = 0; i < 8; ++i) { // vec_size + int idx_0 = j * 8 + i + idx_0_tid; + for (int k = 0; k < 2; ++k) { // Epi_N / IntelPVCEpilogue::SubgroupSize + params.ptr_output2[idx_0 * ROW_2_DIM + idx_1 * ROPE_DIM + (k * thr_stride + rope_idx_2)] = res_tensor(i, j, k); + // ---idx_0-----------------idx_1-------------------idx_2------ + } + } + } + } + // } + // sync_fn(); copy(params.xe_store_output, res_tensor, coord); } else { @@ -217,9 +342,25 @@ struct XeSplitK Tensor gAux = local_tile(mAux_mnl, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord,l_offset)); Tensor tCgAux = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux); + unsigned long sg_n_coord_d1 = sg_n_coord % 6; + unsigned long sg_n_coord_d0 = sg_n_coord / 6; + + unsigned long sg_n_coord_1 = sg_n_coord_d0 * 4 + sg_n_coord_d1; + unsigned long sg_n_coord_2 = sg_n_coord_d0 * 4 + sg_n_coord_d1 - 4; + Tensor mAux_mnl1 = cute::get_pvc_tensor(make_shape(M,8 * 128,L)); + // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel + Tensor gAux1 = local_tile(mAux_mnl1, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord_1,l_offset)); + Tensor tCgAux1 = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux1); + + Tensor mAux_mnl2 = cute::get_pvc_tensor(make_shape(M,8 * 64,L)); + // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel + Tensor gAux2 = local_tile(mAux_mnl2, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord_2,l_offset)); + Tensor tCgAux2 = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux2); return ConsumerStoreCallbacks( - cute::move(res), + cute::move(res), cute::move(tCgAux), + cute::move(tCgAux1), + cute::move(tCgAux2), params); } From 61920a825c6005990bd4caaa257d502b1fdb30e8 Mon Sep 17 00:00:00 2001 From: Yuankun Shi Date: Tue, 13 May 2025 19:25:48 -0700 Subject: [PATCH 4/6] pass small case accuracy test --- .../05_pvc_gemm_with_epilogue_splitk.cpp | 57 ++++---- .../cutlass/epilogue/fusion/xe_callbacks.hpp | 6 +- .../epilogue/fusion/xe_visitor_splitk.hpp | 123 +++++++++--------- 3 files changed, 97 insertions(+), 89 deletions(-) diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp index 52abf785b4..4ac2b00e2a 100644 --- a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp +++ b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp @@ -92,14 +92,15 @@ struct Options { bool help; bool error; - int m, n, k, l, iterations; + int m, n, k, l, num_head, nope_dim, rope_dim, iterations; float alpha, beta; Options(): help(false), error(false), - m(5120), n(4096), k(4096), l(1), iterations(100), - alpha(1.f), beta(0.f) + m(5120), n(4096), k(4096), l(1), + num_head(128), nope_dim(128), rope_dim(64), + iterations(100), alpha(1.f), beta(0.f) { } // Parses the command line @@ -115,6 +116,9 @@ struct Options { cmd.get_cmd_line_argument("n", n, 24576); cmd.get_cmd_line_argument("k", k, 1536); cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("num-head", num_head, 128); + cmd.get_cmd_line_argument("nope-dim", nope_dim, 128); + cmd.get_cmd_line_argument("rope-dim", rope_dim, 64); cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations, 0); @@ -130,6 +134,9 @@ struct Options { << " --n= Sets the N extent of the GEMM\n" << " --k= Sets the K extent of the GEMM\n" << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --num-head= Sets the num_head for splitk fusion\n" + << " --nope-dim= Sets the nope_dim for splitk fusion\n" + << " --rope-dim= Sets the rope_dim for splitk fusion\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n\n" << " --iterations= Iterations\n\n"; @@ -167,9 +174,6 @@ struct ExampleRunner { using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - static constexpr int NUM_HEAD = 128; - static constexpr int NOPE_DIM = 128; - static constexpr int ROPE_DIM = 64; // // Data members // @@ -193,9 +197,9 @@ struct ExampleRunner { // Methods // - bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + bool verify(const ProblemShapeType& problem_size, const ProblemShapeType &splitk_size, ElementCompute alpha, ElementCompute beta) { auto [M, N, K, L] = problem_size; - + auto [NUM_HEAD, NOPE_DIM, ROPE_DIM, _] = splitk_size; cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); @@ -231,17 +235,17 @@ struct ExampleRunner { syclcompat::memcpy(ptr_ref_D, block_ref_D.get(), M * N * L * sizeof(ElementOutput)); syclcompat::wait(); - printf("res:"); - for (int l = 0; l < L; ++l) { - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; ++n) { - if ((n % 16) == 0) - printf("\n(%03d:%04d): ", m, n); - printf("% 7.3f ", ptr_ref_D[l * M * N + m * N + n]); - } - } - } - printf("\n"); + // printf("res:"); + // for (int l = 0; l < L; ++l) { + // for (int m = 0; m < M; ++m) { + // for (int n = 0; n < N; ++n) { + // if ((n % 16) == 0) + // printf("\n(%03d:%04d): ", m, n); + // printf("% 7.3f ", ptr_ref_D[l * M * N + m * N + n]); + // } + // } + // } + // printf("\n"); for (int l = 0; l < L; l++) { for (int i = 0; i < M; i++) { for (int j = 0; j < NUM_HEAD; j++) { @@ -303,9 +307,7 @@ struct ExampleRunner { } } } - constexpr int NUM_HEAD = 8; - constexpr int NOPE_DIM = 128; - constexpr int ROPE_DIM = 64; + printf("CHECK d1:\n"); // check d1 for (int b = 0; b < L; b++) { @@ -372,9 +374,10 @@ struct ExampleRunner { } /// Initialize operands to be used in the GEMM and reference GEMM - void initialize(const ProblemShapeType& problem_size) { + void initialize(const ProblemShapeType& problem_size, const ProblemShapeType &splitk_size) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; + auto [NUM_HEAD, NOPE_DIM, ROPE_DIM, _] = splitk_size; stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); @@ -396,14 +399,18 @@ struct ExampleRunner { cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + ProblemShapeType splitk_size = ProblemShapeType{options.num_head, options.nope_dim, options.rope_dim, 1}; - initialize(problem_size); + initialize(problem_size, splitk_size); using EpilogueArguments = typename Gemm::GemmKernel::EpilogueArguments; EpilogueArguments epilogue_arguments{ {options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}; epilogue_arguments.thread.output_ptr = block_D.get(); epilogue_arguments.thread.output_ptr1 = block_D1.get(); epilogue_arguments.thread.output_ptr2 = block_D2.get(); + epilogue_arguments.thread.NUM_HEAD = options.num_head; + epilogue_arguments.thread.NOPE_DIM = options.nope_dim; + epilogue_arguments.thread.ROPE_DIM = options.rope_dim; typename Gemm::GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, @@ -428,7 +435,7 @@ struct ExampleRunner { syclcompat::wait(); // Verify that the result is correct - bool passed = verify(problem_size, options.alpha, options.beta); + bool passed = verify(problem_size, splitk_size, options.alpha, options.beta); std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; if (!passed) return cutlass::Status::kErrorInternal; diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index da1beeec3e..237f28e2c1 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -220,7 +220,9 @@ struct FusionCallbacks< ElementOutput* output_ptr = nullptr; ElementOutput *output_ptr1 = nullptr; ElementOutput *output_ptr2 = nullptr; - + size_t NUM_HEAD = 0; + size_t NOPE_DIM = 0; + size_t ROPE_DIM = 0; operator typename Impl::Arguments() const { return { // unary op: activation(beta * C + (alpha * acc)) @@ -234,7 +236,7 @@ struct FusionCallbacks< }, // end binary op {} // ternary args : multiply_add }, // end ternary op - {output_ptr, output_ptr1, output_ptr2} // unary args: activation + {output_ptr, output_ptr1, output_ptr2, NUM_HEAD, NOPE_DIM, ROPE_DIM} // unary args: activation }; // end unary op } }; diff --git a/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp b/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp index 0902fb9b24..970e100fff 100644 --- a/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp +++ b/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp @@ -76,6 +76,9 @@ struct XeSplitK ElementOutput* ptr_output; ElementOutput* ptr_output1; ElementOutput* ptr_output2; + size_t NUM_HEAD; + size_t NOPE_DIM; + size_t ROPE_DIM; // StrideOutput dOutput; }; @@ -84,7 +87,10 @@ struct XeSplitK ElementOutput *ptr_output1; ElementOutput *ptr_output2; XE_Copy_output xe_store_output1; - // XE_Copy_output xe_store_output2; + XE_Copy_output xe_store_output2; + size_t NUM_HEAD; + size_t NOPE_DIM; + size_t ROPE_DIM; }; template @@ -97,9 +103,9 @@ struct XeSplitK Layout>>{}, make_layout(make_shape(get<0>(typename XE_Copy_output::BlockShape{}), get<1>(typename XE_Copy_output::BlockShape{}) / Int{}))); - constexpr int NUM_HEAD = 8; - constexpr int NOPE_DIM = 128; - constexpr int ROPE_DIM = 64; + auto NUM_HEAD = args.NUM_HEAD; + auto NOPE_DIM = args.NOPE_DIM; + auto ROPE_DIM = args.ROPE_DIM; XE_Copy_output output1 = make_tiled_copy( Copy_Atom, ElementOutput>{}.with( @@ -108,11 +114,14 @@ struct XeSplitK make_layout(make_shape(get<0>(typename XE_Copy_output::BlockShape{}), get<1>(typename XE_Copy_output::BlockShape{}) / Int{}))); - // Tensor mAux_mnl1 = cute::get_pvc_tensor(make_shape(M, 8 * 128, L)); - // Tensor gAux1 = local_tile(mAux_mnl1, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord,l_offset)); - // Tensor tCgAux1 = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux1); + XE_Copy_output output2 = make_tiled_copy( + Copy_Atom, ElementOutput>{}.with( + args.ptr_output2, M, NUM_HEAD *ROPE_DIM), + Layout>>{}, + make_layout(make_shape(get<0>(typename XE_Copy_output::BlockShape{}), + get<1>(typename XE_Copy_output::BlockShape{}) / Int{}))); - return {output, args.ptr_output1, args.ptr_output2, output1}; + return {output, args.ptr_output1, args.ptr_output2, output1, output2, NUM_HEAD, NOPE_DIM, ROPE_DIM}; } template @@ -217,7 +226,7 @@ struct XeSplitK CoordTensor coord; CoordTensor coord1; CoordTensor coord2; - Params const& params; + Params const& params; template CUTLASS_DEVICE auto visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, @@ -242,53 +251,41 @@ struct XeSplitK Tensor res = make_tensor(static_cast(res_tensor).data(), make_shape(Int{}, Int{}, Int{})); - constexpr int tid = 96; - constexpr int bid = 0; + // constexpr int tid = 160; + // constexpr int bid = 0; auto m_coord = get<0>(coord[0]); auto n_coord = get<1>(coord[0]); auto l_coord = get<2>(coord[0]); - if (cute::thread(tid, bid)) { - printf("res: "); - print(res_tensor); - printf("\n"); - print_tensor_orig(res_tensor); - print_tensor(res_tensor); - printf("\nxe_store:\n"); - print(params.xe_store_output); - printf("\ncoord:\n"); - print(coord); - printf("m %d n %d l %d\n", m_coord, n_coord, l_coord); - } - constexpr int NUM_HEAD = 8; - constexpr int NOPE_DIM = 128; - constexpr int ROPE_DIM = 64; - constexpr int ROW_DIM = NOPE_DIM + ROPE_DIM; - constexpr int ROW_1_DIM = NUM_HEAD * NOPE_DIM; - constexpr int ROW_2_DIM = NUM_HEAD * ROPE_DIM; - constexpr auto thr_stride = 8 * 2; // vec_size*vec_folds - int tidx = ThreadIdxX(); - int bidx = BlockIdxX(); + // if (cute::thread(tid, bid)) { + // printf("res: "); + // print(res_tensor); + // printf("\n"); + // print_tensor_orig(res_tensor); + // print_tensor(res_tensor); + // printf("\nxe_store:\n"); + // print(params.xe_store_output); + // printf("\ncoord:\n"); + // print(coord); + // printf("m %d n %d l %d\n", m_coord, n_coord, l_coord); + // } + auto NUM_HEAD = params.NUM_HEAD; + auto NOPE_DIM = params.NOPE_DIM; + auto ROPE_DIM = params.ROPE_DIM; + auto ROW_DIM = NOPE_DIM + ROPE_DIM; + auto ROW_1_DIM = NUM_HEAD * NOPE_DIM; + auto ROW_2_DIM = NUM_HEAD * ROPE_DIM; + // constexpr auto thr_stride = 8 * 2; // vec_size*vec_folds int col = n_coord; int idx_2 = col % ROW_DIM; - int idx_1 = col / ROW_DIM; - int idx_0_tid = m_coord; - if (cute::thread(tid, bid)) { - printf("tid %d bid %d m %d n %d blockdim %d tidx 0st %d 1st %d 2nd %d\n", tidx, bidx, m_coord, n_coord, BlockDimX(), idx_0_tid, idx_1, idx_2); - } + // int idx_1 = col / ROW_DIM; + // int idx_0_tid = m_coord; + // if (cute::thread(tid, bid)) { + // printf("tid %d bid %d m %d n %d blockdim %d tidx 0st %d 1st %d 2nd %d\n", tidx, bidx, m_coord, n_coord, BlockDimX(), idx_0_tid, idx_1, idx_2); + // } // if ((tidx >= tid) && (tidx < tid + 16) && (bidx == bid)) { if (idx_2 < NOPE_DIM) { // static_assert(cute::is_same_v); - unsigned long n_coord1 = idx_1 * NOPE_DIM + idx_2; - auto coord1 = make_coord(m_coord, n_coord1, l_coord); - coord[0] = coord1; - if (cute::thread(tid, bid)) { - printf("m %d n %d l %d\n", m_coord, n_coord1, l_coord); - print(coord1); - printf("\n"); - print(coord); - printf("\n"); - } - copy(params.xe_store_output1, res_tensor, coord); + copy(params.xe_store_output1, res_tensor, coord1); // copy to first tensor // for (int j = 0; j < 2; ++j) { // vec_folds // for (int i = 0; i < 8; ++i) { // vec_size @@ -300,17 +297,18 @@ struct XeSplitK // } // } } else { - int rope_idx_2 = idx_2 - NOPE_DIM; - // copy to second tensor - for (int j = 0; j < 2; ++j) { // vec_folds - for (int i = 0; i < 8; ++i) { // vec_size - int idx_0 = j * 8 + i + idx_0_tid; - for (int k = 0; k < 2; ++k) { // Epi_N / IntelPVCEpilogue::SubgroupSize - params.ptr_output2[idx_0 * ROW_2_DIM + idx_1 * ROPE_DIM + (k * thr_stride + rope_idx_2)] = res_tensor(i, j, k); - // ---idx_0-----------------idx_1-------------------idx_2------ - } - } - } + // int rope_idx_2 = idx_2 - NOPE_DIM; + // // copy to second tensor + // for (int j = 0; j < 2; ++j) { // vec_folds + // for (int i = 0; i < 8; ++i) { // vec_size + // int idx_0 = j * 8 + i + idx_0_tid; + // for (int k = 0; k < 2; ++k) { // Epi_N / IntelPVCEpilogue::SubgroupSize + // params.ptr_output2[idx_0 * ROW_2_DIM + idx_1 * ROPE_DIM + (k * thr_stride + rope_idx_2)] = res_tensor(i, j, k); + // // ---idx_0-----------------idx_1-------------------idx_2------ + // } + // } + // } + copy(params.xe_store_output2, res_tensor, coord2); } // } // sync_fn(); @@ -342,11 +340,11 @@ struct XeSplitK Tensor gAux = local_tile(mAux_mnl, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord,l_offset)); Tensor tCgAux = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux); - unsigned long sg_n_coord_d1 = sg_n_coord % 6; + unsigned long sg_n_coord_d1 = sg_n_coord % 6; // 192 unsigned long sg_n_coord_d0 = sg_n_coord / 6; - unsigned long sg_n_coord_1 = sg_n_coord_d0 * 4 + sg_n_coord_d1; - unsigned long sg_n_coord_2 = sg_n_coord_d0 * 4 + sg_n_coord_d1 - 4; + unsigned long sg_n_coord_1 = sg_n_coord_d0 * 4 + sg_n_coord_d1; // nope_dim 128 + unsigned long sg_n_coord_2 = sg_n_coord_d0 * 2 + sg_n_coord_d1 - 4; // rope_dim 64 Tensor mAux_mnl1 = cute::get_pvc_tensor(make_shape(M,8 * 128,L)); // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel Tensor gAux1 = local_tile(mAux_mnl1, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord_1,l_offset)); @@ -356,6 +354,7 @@ struct XeSplitK // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel Tensor gAux2 = local_tile(mAux_mnl2, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord_2,l_offset)); Tensor tCgAux2 = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux2); + return ConsumerStoreCallbacks( cute::move(res), cute::move(tCgAux), From b7eb5471a3a47ac10a375e68e1a739f04ecc93bb Mon Sep 17 00:00:00 2001 From: Yuankun Shi Date: Wed, 14 May 2025 16:26:26 -0700 Subject: [PATCH 5/6] remove unnecessary code for splitk fusion --- .../05_pvc_gemm_with_epilogue_splitk.cpp | 38 +------ .../epilogue/fusion/xe_visitor_splitk.hpp | 107 ++---------------- 2 files changed, 11 insertions(+), 134 deletions(-) diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp index 4ac2b00e2a..30001f7230 100644 --- a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp +++ b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_splitk.cpp @@ -121,7 +121,7 @@ struct Options { cmd.get_cmd_line_argument("rope-dim", rope_dim, 64); cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); - cmd.get_cmd_line_argument("iterations", iterations, 0); + cmd.get_cmd_line_argument("iterations", iterations, 100); } /// Prints the usage statement. @@ -235,17 +235,6 @@ struct ExampleRunner { syclcompat::memcpy(ptr_ref_D, block_ref_D.get(), M * N * L * sizeof(ElementOutput)); syclcompat::wait(); - // printf("res:"); - // for (int l = 0; l < L; ++l) { - // for (int m = 0; m < M; ++m) { - // for (int n = 0; n < N; ++n) { - // if ((n % 16) == 0) - // printf("\n(%03d:%04d): ", m, n); - // printf("% 7.3f ", ptr_ref_D[l * M * N + m * N + n]); - // } - // } - // } - // printf("\n"); for (int l = 0; l < L; l++) { for (int i = 0; i < M; i++) { for (int j = 0; j < NUM_HEAD; j++) { @@ -285,28 +274,6 @@ struct ExampleRunner { uint32_t err_cnt = 0; constexpr float atol = 1e-4; constexpr float rtol = 1e-4; - for (int b = 0; b < L; b++) { - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - int idx = b * M * N + i * N + j; - auto expect = ptr_ref_D[idx]; - auto val = ptr_test_D[idx]; - if (not (std::isinf(val) || std::isnan(val))) { - if (not isclose(val, expect, atol, rtol)) { - std::cout << "(" << b << ", " << i << ", " << j - << "): " << "host: " << expect - << " and device: " << val << std::endl; - err_cnt++; - } - } else { - std::cout << "(" << b << ", " << i << ", " << j - << "): " << "host: " << expect << " and device: " << val - << std::endl; - err_cnt++; - } - } - } - } printf("CHECK d1:\n"); // check d1 @@ -378,7 +345,8 @@ struct ExampleRunner { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; auto [NUM_HEAD, NOPE_DIM, ROPE_DIM, _] = splitk_size; - + assert((NOPE_DIM % 32 == 0) && (NOPE_DIM / 32>0) && "NOPE_DIM should be divisible by 32"); + assert((ROPE_DIM % 32 == 0) && (ROPE_DIM / 32>0) && "ROPE_DIM should be divisible by 32"); stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); diff --git a/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp b/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp index 970e100fff..7e24f728b0 100644 --- a/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp +++ b/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp @@ -174,46 +174,9 @@ struct XeSplitK get_producer_load_callbacks(ProducerLoadArgs const& args) { return EmptyProducerLoadCallbacks{}; } - template - CUTLASS_DEVICE static void - print_tensor_orig(VTensor &t) { - print(t); - auto t_shape = t.shape(); - auto t_stride = t.stride(); - auto t_rank = rank(t_shape); - auto total = t.size(); - for (auto i = 0; i < total; ++i) { - if ((i % get<0>(t_shape)) == 0) - print("\n%2d: ", i / get<0>(t_shape)); - print("%7.3f ", t[i]); - } - print("\n"); - } - template - CUTLASS_DEVICE static void - print_tensor(VTensor &t) { - print(t); - for (int k = 0; k < 2; ++k) { - for (int j = 0; j < 2; ++j) { - printf("\n%2d: ", (k * 2 + j) * 8); - for (int i = 0; i < 8; ++i) { - printf("%7.3f ", t(i, j, k)); - } - } - } - printf("\n"); - } - template - CUTLASS_DEVICE static void - print_coord(CoordTensor &t) { - for (int i = 0; i < t.kRank; ++i) { - printf("%d ", t.idx[i]); - } - printf("\n"); - } template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - + CUTLASS_DEVICE ConsumerStoreCallbacks(RTensor&& res_tensor, CoordTensor&& coord, CoordTensor&& coord1, CoordTensor&& coord2, Params const& params) : res_tensor(cute::forward(res_tensor)), @@ -231,7 +194,7 @@ struct XeSplitK CUTLASS_DEVICE auto visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, Array const& frg_input) { - + return frg_acc; } @@ -245,74 +208,20 @@ struct XeSplitK constexpr auto vec_size = min(Epi_M, Sg_N); constexpr auto vec_folds = Epi_M / vec_size; - auto smem = syclcompat::local_mem(); - Tensor stensor = make_tensor(make_smem_ptr(smem), make_shape(Int{}, Int{}, Int{})); - Tensor res = make_tensor(static_cast(res_tensor).data(), make_shape(Int{}, Int{}, Int{})); - // constexpr int tid = 160; - // constexpr int bid = 0; - auto m_coord = get<0>(coord[0]); auto n_coord = get<1>(coord[0]); - auto l_coord = get<2>(coord[0]); - // if (cute::thread(tid, bid)) { - // printf("res: "); - // print(res_tensor); - // printf("\n"); - // print_tensor_orig(res_tensor); - // print_tensor(res_tensor); - // printf("\nxe_store:\n"); - // print(params.xe_store_output); - // printf("\ncoord:\n"); - // print(coord); - // printf("m %d n %d l %d\n", m_coord, n_coord, l_coord); - // } - auto NUM_HEAD = params.NUM_HEAD; auto NOPE_DIM = params.NOPE_DIM; auto ROPE_DIM = params.ROPE_DIM; auto ROW_DIM = NOPE_DIM + ROPE_DIM; - auto ROW_1_DIM = NUM_HEAD * NOPE_DIM; - auto ROW_2_DIM = NUM_HEAD * ROPE_DIM; - // constexpr auto thr_stride = 8 * 2; // vec_size*vec_folds int col = n_coord; int idx_2 = col % ROW_DIM; - // int idx_1 = col / ROW_DIM; - // int idx_0_tid = m_coord; - // if (cute::thread(tid, bid)) { - // printf("tid %d bid %d m %d n %d blockdim %d tidx 0st %d 1st %d 2nd %d\n", tidx, bidx, m_coord, n_coord, BlockDimX(), idx_0_tid, idx_1, idx_2); - // } - // if ((tidx >= tid) && (tidx < tid + 16) && (bidx == bid)) { if (idx_2 < NOPE_DIM) { - // static_assert(cute::is_same_v); copy(params.xe_store_output1, res_tensor, coord1); - // copy to first tensor - // for (int j = 0; j < 2; ++j) { // vec_folds - // for (int i = 0; i < 8; ++i) { // vec_size - // int idx_0 = j * 8 + i + idx_0_tid; - // for (int k = 0; k < 2; ++k) { // Epi_N / IntelPVCEpilogue::SubgroupSize - // params.ptr_output1[idx_0 * ROW_1_DIM + idx_1 * NOPE_DIM + (k * thr_stride + idx_2)] = res_tensor(i, j, k); - // // ---idx_0-----------------idx_1-------------------idx_2------ - // } - // } - // } } else { - // int rope_idx_2 = idx_2 - NOPE_DIM; - // // copy to second tensor - // for (int j = 0; j < 2; ++j) { // vec_folds - // for (int i = 0; i < 8; ++i) { // vec_size - // int idx_0 = j * 8 + i + idx_0_tid; - // for (int k = 0; k < 2; ++k) { // Epi_N / IntelPVCEpilogue::SubgroupSize - // params.ptr_output2[idx_0 * ROW_2_DIM + idx_1 * ROPE_DIM + (k * thr_stride + rope_idx_2)] = res_tensor(i, j, k); - // // ---idx_0-----------------idx_1-------------------idx_2------ - // } - // } - // } copy(params.xe_store_output2, res_tensor, coord2); } - // } - // sync_fn(); - copy(params.xe_store_output, res_tensor, coord); } else { for(int epi_v = 0; epi_v < visit_results(0).size(); epi_v++) { @@ -321,7 +230,7 @@ struct XeSplitK } } }; - + template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy class... Args @@ -339,18 +248,18 @@ struct XeSplitK // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel Tensor gAux = local_tile(mAux_mnl, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord,l_offset)); Tensor tCgAux = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux); + unsigned long sg_n_coord_d1 = sg_n_coord % 6; // syk have to hard code here 192 + unsigned long sg_n_coord_d0 = sg_n_coord / 6; // 6*32=192 - unsigned long sg_n_coord_d1 = sg_n_coord % 6; // 192 - unsigned long sg_n_coord_d0 = sg_n_coord / 6; - + auto NUM_HEAD = N / 192; unsigned long sg_n_coord_1 = sg_n_coord_d0 * 4 + sg_n_coord_d1; // nope_dim 128 unsigned long sg_n_coord_2 = sg_n_coord_d0 * 2 + sg_n_coord_d1 - 4; // rope_dim 64 - Tensor mAux_mnl1 = cute::get_pvc_tensor(make_shape(M,8 * 128,L)); + Tensor mAux_mnl1 = cute::get_pvc_tensor(make_shape(M,NUM_HEAD * 128,L)); // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel Tensor gAux1 = local_tile(mAux_mnl1, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord_1,l_offset)); Tensor tCgAux1 = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux1); - Tensor mAux_mnl2 = cute::get_pvc_tensor(make_shape(M,8 * 64,L)); + Tensor mAux_mnl2 = cute::get_pvc_tensor(make_shape(M,NUM_HEAD * 64,L)); // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel Tensor gAux2 = local_tile(mAux_mnl2, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord_2,l_offset)); Tensor tCgAux2 = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux2); From 8c93306da23db77a9bd678d7787c869df2d0c9aa Mon Sep 17 00:00:00 2001 From: Yuankun Shi Date: Thu, 15 May 2025 15:06:23 -0700 Subject: [PATCH 6/6] update splitk epilogue --- .../epilogue/fusion/xe_visitor_splitk.hpp | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp b/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp index 7e24f728b0..23517d2eaa 100644 --- a/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp +++ b/include/cutlass/epilogue/fusion/xe_visitor_splitk.hpp @@ -248,18 +248,23 @@ struct XeSplitK // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel Tensor gAux = local_tile(mAux_mnl, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord,l_offset)); Tensor tCgAux = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux); - unsigned long sg_n_coord_d1 = sg_n_coord % 6; // syk have to hard code here 192 - unsigned long sg_n_coord_d0 = sg_n_coord / 6; // 6*32=192 - - auto NUM_HEAD = N / 192; - unsigned long sg_n_coord_1 = sg_n_coord_d0 * 4 + sg_n_coord_d1; // nope_dim 128 - unsigned long sg_n_coord_2 = sg_n_coord_d0 * 2 + sg_n_coord_d1 - 4; // rope_dim 64 - Tensor mAux_mnl1 = cute::get_pvc_tensor(make_shape(M,NUM_HEAD * 128,L)); + auto nope_dim = params.NOPE_DIM; + auto rope_dim = params.ROPE_DIM; + auto inner_dim_sg = (nope_dim + rope_dim) / 32; + auto nope_dim_sg = nope_dim / 32; + auto rope_dim_sg = rope_dim / 32; + unsigned long sg_n_coord_d1 = sg_n_coord % inner_dim_sg; + unsigned long sg_n_coord_d0 = sg_n_coord / inner_dim_sg; + + auto num_head = N / inner_dim_sg; + unsigned long sg_n_coord_1 = sg_n_coord_d0 * nope_dim_sg + sg_n_coord_d1; + unsigned long sg_n_coord_2 = sg_n_coord_d0 * rope_dim_sg + sg_n_coord_d1 - nope_dim_sg; + Tensor mAux_mnl1 = cute::get_pvc_tensor(make_shape(M,num_head * nope_dim,L)); // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel Tensor gAux1 = local_tile(mAux_mnl1, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord_1,l_offset)); Tensor tCgAux1 = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux1); - Tensor mAux_mnl2 = cute::get_pvc_tensor(make_shape(M,NUM_HEAD * 64,L)); + Tensor mAux_mnl2 = cute::get_pvc_tensor(make_shape(M,num_head * rope_dim,L)); // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel Tensor gAux2 = local_tile(mAux_mnl2, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord_2,l_offset)); Tensor tCgAux2 = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux2);