diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp new file mode 100644 index 0000000000..5656b4a918 --- /dev/null +++ b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp @@ -0,0 +1,454 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_relu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline +bool isclose(T a, T b, float atol, float rtol) { + return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); +} + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta, eps; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // 
Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 512); + cmd.get_cmd_line_argument("n", n, 512); + cmd.get_cmd_line_argument("k", k, 64); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("eps", eps, 1e-5f); + cmd.get_cmd_line_argument("iterations", iterations, 1); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "PVC GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + using LayoutW = cutlass::layout::RowMajor; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + using ElementW = typename Gemm::ElementA; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using 
ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 10; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + cutlass::DeviceAllocation block_W; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + cutlass::TensorRef ref_W(block_W.get(), LayoutW::packed({1, N})); + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + syclcompat::wait(); + + ElementOutput *ptr_refD = + (ElementOutput *)std::malloc(M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr_refD, block_ref_D.get(), + M * N * L * sizeof(ElementOutput)); + syclcompat::wait(); + ElementW *ptr_wgt = + (ElementW *)std::malloc(N * L * sizeof(ElementW)); + syclcompat::memcpy(ptr_wgt, block_W.get(), + N * L * sizeof(ElementW)); + syclcompat::wait(); + + constexpr float eps = 1e-5; + // rowwise rmsnorm + for (int l = 0; l < L; l++) { + for 
(int m = 0; m < M; m++) { + float pow2_sum = (float)0; + for (int n = 0; n < N; n++) { + pow2_sum += pow(ptr_refD[l * M * N + m * N + n], 2); + } + float rms = 1.0f / sqrt(pow2_sum / (float)N + eps); + + for (int n = 0; n < N; n++) { + ptr_refD[l * M * N + m * N + n] = ptr_refD[l * M * N + m * N + n] * rms * (float)ptr_wgt[n]; + } + } + } + + ElementOutput *ptr_D = + (ElementOutput *)std::malloc((size_t)M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr_D, block_D.get(), + (size_t)M * N * L * sizeof(ElementOutput)); + syclcompat::wait(); + + uint32_t err_cnt = 0; + + float atol = 1e-3; + float rtol = 1e-3; + for (int b = 0; b < L; b++) { + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + int idx = b * M * N + m * N + n; + auto expect = ptr_refD[idx]; + auto val = ptr_D[idx]; + + if (not (std::isinf(ptr_D[idx]) || std::isnan(ptr_D[idx]))) { + if (not isclose(val, expect, atol, rtol)) { + printf("(%d,%d,%d): host: %f and device: %f ratio: %f\n", b, n, m, expect, val, val / expect); + err_cnt++; + } + } else { + printf("(%d,%d,%d): host: %f and device: %f\n", b, n, m, expect, val); + err_cnt++; + } + } + } + } + + std::free(ptr_refD); + std::free(ptr_D); + std::cout << "err count: " << err_cnt + << ", pass rate: " << 100 - (100 * err_cnt / (M * N * L)) << "%" + << std::endl; + return err_cnt == 0; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N 
* L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + block_W.reset(N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + initialize_block(block_W, seed + 2020); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + using EpilogueArguments = typename Gemm::GemmKernel::EpilogueArguments; + EpilogueArguments epilogue_arguments{ + {options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}; + epilogue_arguments.thread.output_ptr = block_D.get(); + epilogue_arguments.thread.weight_ptr = block_W.get(); + epilogue_arguments.thread.eps = options.eps; + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + epilogue_arguments, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + gemm_op.can_implement(arguments); + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + if (!passed) return cutlass::Status::kErrorInternal; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + double io = + options.l * + (options.m * options.k * sizeof(ElementA) + options.k * options.n * sizeof(ElementB) + + options.m * options.n * sizeof(ElementOutput)) * + 1e-9; + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]GB/s, [%4.3f]TF/s, [%6.4f]ms\n", io / cute_time, tflops/cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. 
+  using ElementAccumulator = float;      // <- data type of accumulator
+  using ElementComputeEpilogue = float;  // <- data type of epilogue operations
+  using ElementInputA = bfloat16_t;      // <- data type of elements in input matrix A
+  using ElementInputB = bfloat16_t;      // <- data type of elements in input matrix B
+  using ElementWeight = bfloat16_t;      // <- data type for elements in rmsnorm weight
+  using ElementOutput = float;           // <- data type of elements in output matrix D
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::RowMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using GmemTiledCopyA = XE_2D_U16x8x16_LD_N;
+  using GmemTiledCopyB = XE_2D_U16x16x16_LD_V;
+
+  // Workgroup-level tile (M x N x K); the RMSNorm visitor's can_implement requires problem N <= this tile's N
+  using TileShape = Shape<_32, _512, _32>;
+
+  using TiledMma =
+      typename TiledMMAHelper, Layout,
+                                    Layout, Stride<_16, _1, _0>>>::TiledMMA;
+
+  using EpilogueTile = Shape<_16, _32>;
+  constexpr int PipelineStages = 3;
+  using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16;
+  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
+
+  using EpilogueOp = cutlass::epilogue::fusion::LinCombRMSNormRow;
+
+  using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks;
+  using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
+          EpilogueDispatchPolicy,
+          TileShape,
+          ElementAccumulator,
+          cutlass::gemm::TagToStrideC_t,
+          ElementOutput,
+          cutlass::gemm::TagToStrideC_t,
+          FusionCallBacks,
+          XE_2D_U32x8x16_LD_N,
+          void, void,
+          void,
+          void, void>;
+
+// Mainloop
+  using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
+          GEMMDispatchPolicy,
+          TileShape,
+          ElementInputA,
+          cutlass::gemm::TagToStrideA_t,
+          ElementInputB,
+          cutlass::gemm::TagToStrideB_t,
+          TiledMma,
+          GmemTiledCopyA, void, void, cute::identity,  // A
+          GmemTiledCopyB, void, void, cute::identity   // B
+  >;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+  Shape,
+  CollectiveMainloop,
+  CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter;
+
+  ExampleRunner runner;
+
+  CUTLASS_CHECK(runner.run(options, hw_info));
+
+  return 0;
+}
diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt b/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt
index f8a49df335..f00c412288 100644
--- a/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt
+++ b/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt
@@ -81,3 +81,10 @@ cutlass_example_add_executable(
   TEST_COMMAND_OPTIONS
   TEST_BATCHES
 )
+
+cutlass_example_add_executable(
+  05_pvc_gemm_with_epilogue_rmsnorm
+  05_pvc_gemm_with_epilogue_rmsnorm.cpp
+  TEST_COMMAND_OPTIONS
+  TEST_BATCHES
+)
diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp
index f37eb5b00a..9cd2434575 100644
--- a/include/cutlass/epilogue/fusion/operations.hpp
+++ b/include/cutlass/epilogue/fusion/operations.hpp
@@ -159,6 +159,20 @@ struct LinCombSoftmaxRow : LinearCombination {
 };
 
 
+// D = rmsnorm(alpha * acc + beta * C); operation tag matched by the IntelXeXMX16 FusionCallbacks specialisation
+template<
+  class ElementWeight_,
+  class ElementOutput_,
+  class ElementCompute_,
+  class CopyOpR2G_,
+  class ElementSource_ = ElementOutput_,
+  class ElementScalar_ = ElementCompute_,
+  FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
+>
+struct LinCombRMSNormRow
+    : LinearCombination {
+};
+
 // D = alpha * acc + beta * C + per-row bias
 template<
   class ElementOutput_,
diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp
index 9b616cab1b..e264a0bcb7 100644
--- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp
+++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp
@@ -48,6 +48,7 @@
 #include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp"
 #include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp"
 #include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp"
+#include 
"cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp"
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -240,6 +241,88 @@ struct FusionCallbacks<
   using Impl::Impl;
 };
 
+// D = rmsnorm(alpha * acc + beta * C): an rmsnorm node applied on top of a linear-combination subtree
+template<
+  // int FragmentSize,
+  class CtaTileShapeMNK,
+  class EpilogueTile,
+  class ElementWeight,
+  class ElementOutput,
+  class ElementCompute,
+  class CopyOpR2G,
+  class ElementSource = ElementOutput,
+  class ElementScalar = ElementCompute,
+  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
+>
+using XeLinCombRMSNormRow =
+  Sm90EVT, // rmsnorm(beta * C + (alpha * acc))
+    Sm90LinearCombination // beta * C + (alpha * acc)
+  >;
+
+template <
+  // int FragmentSize,
+  class ElementWeight_,
+  class ElementOutput_,
+  class ElementCompute_,
+  class ElementSource_,
+  class ElementScalar_,
+  class CopyOpR2G_,
+  FloatRoundStyle RoundStyle,
+  class CtaTileShapeMNK,
+  class EpilogueTile
+>
+struct FusionCallbacks<
+    epilogue::IntelXeXMX16,
+    fusion::LinCombRMSNormRow,
+    CtaTileShapeMNK,
+    EpilogueTile
+> : XeLinCombRMSNormRow {
+  // Specialisation selected for the IntelXeXMX16 dispatch policy with the LinCombRMSNormRow operation.
+  using ElementWeight = ElementWeight_;
+  using ElementOutput = ElementOutput_;
+  using ElementCompute = ElementCompute_;
+  using ElementSource = ElementSource_;
+  using ElementScalar = ElementScalar_;
+  using Impl = XeLinCombRMSNormRow::type,
+                                    typename cutlass::detail::get_unpacked_element_type::type,
+                                    ElementCompute, CopyOpR2G_, ElementSource, ElementScalar, RoundStyle>;
+  using Operation = fusion::LinCombRMSNormRow;
+
+  struct Arguments {
+    ElementScalar alpha = ElementScalar(1);
+    ElementScalar beta = ElementScalar(0);
+    ElementScalar const* alpha_ptr = nullptr;
+    ElementScalar const* beta_ptr = nullptr;
+    ElementOutput* output_ptr = nullptr;
+    // RMSNorm inputs: weight pointer and epsilon, forwarded to the rmsnorm node together with output_ptr.
+    using StrideWeight = Stride<_1, _0, int64_t>;
+    ElementWeight const* weight_ptr = nullptr;
+    float eps = 1e-5;
+    StrideWeight dWeight = {};  // declared but not forwarded by the conversion operator below
+
+    operator typename Impl::Arguments() const {
+      return
+        {    // unary op: rmsnorm(beta * C + (alpha * acc))
+          {    // ternary op : beta * C + (alpha * acc)
+            {{beta}, {beta_ptr}}, // leaf args : beta
+            {},                   // leaf args : C
+            {                     // binary op : alpha * acc
+              {{alpha}, {alpha_ptr}}, // leaf args : alpha
+              {},                     // leaf args : acc
+              {}                  // binary args : multiplies
+            },                    // end binary op
+            {} // ternary args : multiply_add
+          },   // end ternary op
+          {output_ptr, weight_ptr, eps} // unary args: rmsnorm (output pointer, weight pointer, epsilon)
+        };   // end unary op
+    }
+  };
+
+  // Ctor inheritance
+  using Impl::Impl;
+};
+
 template<
   class StrideAux,
   class CopyOpG2R,
diff --git a/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp
new file mode 100644
index 0000000000..fb2c91a791
--- /dev/null
+++ b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp
@@ -0,0 +1,281 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree RMSNorm fusion operation for the Intel PVC epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "xe_visitor_softmax.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileShapeMNK, + class EpilogueTile, + class ElementWeight, + class ElementOutput, + class ElementCompute, + class CopyOpR2G, + FloatRoundStyle RoundStyle +> +struct XeRMSNormRowReduction +{ +public: + static constexpr int FragmentSize = 8; + static constexpr auto Tile_M = get<0>(CtaTileShapeMNK{}); + static constexpr auto Tile_N = get<1>(CtaTileShapeMNK{}); + static constexpr auto Epi_M = get<0>(EpilogueTile{}); + static constexpr auto Epi_N = get<1>(EpilogueTile{}); + static constexpr auto Sg_M = Tile_M / Epi_M; + static constexpr auto Sg_N = Tile_N / Epi_N; + static constexpr auto Sg_Nums = Sg_M * Sg_N; + + using Trait_Output = Copy_Traits; + using XE_Copy_output = decltype(make_tiled_copy(Copy_Atom{} + .with(static_cast(nullptr),int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename Trait_Output::BlockShape{}), + get<1>(typename Trait_Output::BlockShape{}) / Int{})))); + + 
struct SharedStorage { }; + + struct Arguments { + ElementOutput* ptr_output; + ElementWeight const*ptr_weight; + const float eps; + // StrideOutput dOutput; + }; + + struct Params { + XE_Copy_output xe_store_output; + ElementWeight const *weight; + float eps; + int inner_dim; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + XE_Copy_output output = make_tiled_copy(Copy_Atom, ElementOutput>{}.with( + args.ptr_output, M, N), + Layout>>{}, + make_layout(make_shape(get<0>(typename XE_Copy_output::BlockShape{}), + get<1>(typename XE_Copy_output::BlockShape{}) / Int{}))); + return {output, args.ptr_weight, args.eps, N}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + // Cross CTA reduction is not possible because there is no guarantee that all CTAs run + // concurrently. + // Cross epilogue tile reduction is possible, but re-visiting and applying reduction + // to accumulators is only possible for the current epilogue tile. 
+ auto [epi_M, epi_N] = EpilogueTile{}; + return N <= tile_N; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + XeRMSNormRowReduction() { } + + CUTLASS_HOST_DEVICE + XeRMSNormRowReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& res_tensor, CoordTensor&& coord, Params const& params) + : res_tensor(cute::forward(res_tensor)), + coord(cute::forward(coord)), + params(params) {} + + RTensor res_tensor; + CoordTensor coord; + Params const& params; + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + return frg_acc; + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + const float eps = params.eps; + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + auto group = syclcompat::get_nd_item<1>().get_group()[0]; + auto group_id = group; + auto sg_group_id = sg.get_group_id(); + auto sg_local_id = sg.get_local_id()[0]; + if(is_last_iteration) { + for(int epi_v = 0; epi_v < visit_results(0).size(); epi_v++) { + res_tensor(epi_v, epi_m, epi_n) = visit_results(0)[epi_v]; + } + + 
constexpr auto vec_size = min(Epi_M, Sg_N); + constexpr auto vec_folds = Epi_M / vec_size; + auto smem = syclcompat::local_mem(); + Tensor stensor = make_tensor(make_smem_ptr(smem), make_shape(Int{}, Int{}, Int{})); + auto wgt_ptr=params.weight; + Tensor res = + make_tensor(static_cast(res_tensor).data(), + make_shape(Int{}, Int{}, Int{})); + // square + Tensor pow2_buff = make_tensor_like(res); + CUTLASS_PRAGMA_UNROLL + for (int loop = 0; loop < vec_folds; loop++) { + auto loop_t = res(_, loop, _); + auto pow2_t = pow2_buff(_, loop, _); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Epi_N / IntelXeXMX16::SubgroupSize; i++) { + auto x_vec = loop_t(_, i); + auto p2_vec = pow2_t(_, i); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < vec_size; j++) { + p2_vec(j) = x_vec(j) * x_vec(j); + } + } + } + int gx = syclcompat::global_id::x() % 256; + int gy = syclcompat::global_id::y(); + auto gid = gx / 16 * 32 + gx % 16; + if (cute::thread0()) { + printf("threadx %d blockx %d blockdimx %d\n", ThreadIdxX(), BlockIdxX(), BlockDimX()); + } + CUTLASS_PRAGMA_UNROLL + for (int loop = 0; loop < vec_folds; loop++) { + auto loop_t = res(_, loop, _); + auto pow2_t = pow2_buff(_, loop, _); + Tensor group_sum = make_tensor(make_shape(Int{})); + float rev_dim = 1 / (float)params.inner_dim; + group_reduce_sum(stensor, pow2_t, group_sum); + Tensor rms = make_tensor(make_shape(Int{})); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < vec_size; ++i) { + rms(i) = pow(group_sum(i) * rev_dim + eps, -0.5); + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Epi_N / IntelXeXMX16::SubgroupSize; i++) { + const float wgt_per_col = (float)wgt_ptr[gid + i * IntelXeXMX16::SubgroupSize]; + auto rmsnorm_vec = loop_t(_, i); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < vec_size; j++) { + rmsnorm_vec(j) = rmsnorm_vec(j) * rms(j) * wgt_per_col; + } + } + } + + copy(params.xe_store_output, res_tensor, coord); + } + else { + for(int epi_v = 0; epi_v < visit_results(0).size(); epi_v++) { + res_tensor(epi_v, epi_m, 
epi_n) = visit_results(0)[epi_v]; + } + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + using MmaAtomShape = typename decltype(args.tiled_mma)::AtomShape_MNK; + static constexpr int FragsM = get<0>(EpilogueTile{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(EpilogueTile{}) / get<1>(MmaAtomShape()); // B frags per sub_group + Tensor res = make_tensor(Shape, Int, Int>{}); + + auto [sg_m_coord, sg_n_coord, k_coord, l_offset] = args.tile_coord_mnkl; + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mAux_mnl = cute::get_xe_tensor(make_shape(M,N,L)); + // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel + Tensor gAux = local_tile(mAux_mnl, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord,l_offset)); + Tensor tCgAux = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux); + return ConsumerStoreCallbacks( + cute::move(res), + cute::move(tCgAux), + params); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +/////////////////////////////////////////////////////////////////////////////////////////////////