From be4eb69e6586a69af8395cc275c54d066f8e916e Mon Sep 17 00:00:00 2001
From: Yuankun Shi
Date: Wed, 9 Apr 2025 23:36:46 -0700
Subject: [PATCH 1/6] upload draft for gemm rms norm

---
 .../gemm/collective/xe_mma_rmsnorm.hpp | 308 ++++++++++++++++++
 1 file changed, 308 insertions(+)
 create mode 100644 include/cutlass/gemm/collective/xe_mma_rmsnorm.hpp

diff --git a/include/cutlass/gemm/collective/xe_mma_rmsnorm.hpp b/include/cutlass/gemm/collective/xe_mma_rmsnorm.hpp
new file mode 100644
index 0000000000..5c663ec10a
--- /dev/null
+++ b/include/cutlass/gemm/collective/xe_mma_rmsnorm.hpp
@@ -0,0 +1,308 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+ /*
+  * This implements W8A8 GEMM (FP8 weights and activations) using FP16 compute as a workaround,
+  * since current Intel GPUs (e.g., PVC, BMG) lack native FP8 support.
+  * The kernel converts FP8 inputs to FP16 on-the-fly and performs GEMM using FP16 MMA.
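+  *
+  * This draft also sketches a fused row-wise RMSNorm step. For a row x of length N with
+  * per-column weights w and a small eps, the intended math (matching the host reference
+  * added in PATCH 2/6) is:
+  *
+  *   y_j = w_j * x_j / sqrt( (1/N) * sum_k x_k^2 + eps )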
+ */
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+
+#include "cute/algorithm/functional.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cute/algorithm/gemm.hpp"
+#include "cute/tensor_predicate.hpp"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::gemm::collective {
+using namespace cute;
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int Stages, class TileShape_, class ElementA_, class StrideA_, class ElementB_, class StrideB_,
+          class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
+          class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_, class TransformB_>
+struct CollectiveMma<MainloopIntelW8A8<Stages>, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
+    GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_,
+    SmemCopyAtomB_, TransformB_> {
+  //
+  // Type Aliases
+  //
+  using DispatchPolicy = MainloopIntelW8A8<Stages>;
+  using WorkgroupTileShape = TileShape_;
+  using ElementA = ElementA_;
+  using StrideA = StrideA_;
+  using ElementB = ElementB_;
+  using StrideB = StrideB_;
+  using TiledMma = TiledMma_;
+  using ElementAccumulator = typename TiledMma::ValTypeC;
+  using GmemTiledCopyA = GmemTiledCopyA_;
+  using GmemTiledCopyB = GmemTiledCopyB_;
+  using SmemLayoutAtomA = SmemLayoutAtomA_;
+  using SmemLayoutAtomB = SmemLayoutAtomB_;
+  using SmemCopyAtomA = SmemCopyAtomA_;
+  using SmemCopyAtomB = SmemCopyAtomB_;
+  using TransformA = TransformA_;
+  using TransformB = TransformB_;
+  using ArchTag = typename DispatchPolicy::ArchTag;
+
+  static_assert(platform::is_same<ElementA, ElementB>::value, "MainloopIntelW8A8 requires that A and B have same type.");
+  // TODO: support E5M2
+  static_assert(std::is_same_v<ElementA, cutlass::float_e4m3_t>, "ElementA must be fp8 (E4M3)");
+  static_assert(std::is_same_v<ElementB, cutlass::float_e4m3_t>, "ElementB must be fp8 (E4M3)");
+
+  static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
+
+  using MmaAtomShape = typename TiledMma::AtomShape_MNK;
+
+  static constexpr auto BLK_M = get<0>(WorkgroupTileShape{});
+  static constexpr auto BLK_N = get<1>(WorkgroupTileShape{});
+  static constexpr auto BLK_K = get<2>(WorkgroupTileShape{});
+
+  static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
+  static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
+  static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
+
+  static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M);
+  static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N);
+  static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K);
+
+  using SubgroupTileShape = Shape<decltype(SG_M), decltype(SG_N), decltype(SG_K)>;
+  static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
+
+  using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
+
+  using traits_load_A = Copy_Traits<GmemTiledCopyA, StrideA>;
+  using atom_load_A = Copy_Atom<traits_load_A, ElementA>;
+  using val_layout_load_A = decltype(make_layout(shape_div(typename traits_load_A::BlockShape{}, CopyThreadShape{})));
+  using Copy_A = decltype(make_tiled_copy(atom_load_A{}, Layout<CopyThreadShape>{}, val_layout_load_A{}));
+
+  using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
+  using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
+  using val_layout_load_B = decltype(make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{})));
+  using Copy_B = decltype(make_tiled_copy(atom_load_B{}, Layout<CopyThreadShape>{}, val_layout_load_B{}));
+
+  // Host side kernel arguments
+  struct Arguments {
+    ElementA const* ptr_A;
+    StrideA dA;
+    ElementB const* ptr_B;
+    StrideB dB;
+  };
+
+  struct Params {
+    Copy_A tiled_copy_a;
+    Copy_B tiled_copy_b;
+  };
+
+  //
+  // Methods
+  //
+
+  CollectiveMma() = default;
+
+  template <class ProblemShape>
+  static constexpr Params
+  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
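+    // Host-side setup: bind the block-2D copy atoms to the global A (M,K,L) and B (N,K,L)
+    // tensors here, so the device mainloop only has to slice them per work-item.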
+    (void) workspace;
+
+    auto [M,N,K,L] = problem_shape;
+
+    auto mA_mkl = make_tensor(make_gmem_ptr(static_cast<ElementA const*>(args.ptr_A)),
+                              make_layout(make_shape(M, K, L), args.dA));
+    auto mB_nkl = make_tensor(make_gmem_ptr(static_cast<ElementB const*>(args.ptr_B)),
+                              make_layout(make_shape(N, K, L), args.dB));
+
+    Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)};
+    Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};
+
+    return Params{tiled_copy_a, tiled_copy_b};
+  }
+
+  template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut>
+  CUTLASS_DEVICE
+  void convert_E4M3_to_FP16(
+      Tensor<EngineIn, LayoutIn> const& in,
+      Tensor<EngineOut, LayoutOut>& out) {
+
+    static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
+    static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
+    static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
+    static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
+    static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
+
+    using SrcType = typename EngineIn::value_type;
+    using DstType = typename EngineOut::value_type;
+
+    static_assert(std::is_same_v<SrcType, uint8_t>, "Expected fp8 (E4M3) input as uint8_t");
+    static_assert(std::is_same_v<DstType, half_t>, "Expected fp16 output as half_t");
+
+    auto const& src = in(_, _, _);
+    auto const& dst = out(_, _, _);
+
+    SrcType const* pSrc = src.data();
+    DstType* pDst = dst.data();
+
+    constexpr int num_elements = decltype(size(src))::value;
+    constexpr int vec_size = 16;
+    // TODO(Codeplay): Move conversion to NumericArrayConverter
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < num_elements / vec_size; ++i) {
+      // vectorized load
+      cute::intel::uchar16 src_vec;
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 0; j < vec_size; ++j) {
+        src_vec[j] = pSrc[i * vec_size + j];
+      }
+      // vectorized convert fp8 -> fp16
+      cute::intel::ushort16 dst_vec = E4M3_to_FP16_vec16(src_vec);
+      // vectorized store of the raw fp16 bit patterns
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 0; j < vec_size; ++j) {
+        reinterpret_cast<uint16_t*>(pDst)[i * vec_size + j] = dst_vec[j];
+      }
+    }
+  }
+
+  // Draft of a fused row-wise RMSNorm over register fragments. The SLM buffer `scratch`
+  // (one float per work-item) and the per-workgroup accumulator `accum_acc` are assumed to
+  // be provided by the caller; all tensors are assumed to hold float data.
+  template <class EngineIn, class EngineW, class EngineOut, class LayoutIn, class LayoutW, class LayoutOut>
+  CUTLASS_DEVICE
+  void RMSNorm(Tensor<EngineIn, LayoutIn> const &in,
+               Tensor<EngineW, LayoutW> const &w,
+               Tensor<EngineOut, LayoutOut> &out,
+               float *scratch, float *accum_acc, float eps = 1e-5f) {
+    using SrcType = typename EngineIn::value_type;
+    using DstType = typename EngineOut::value_type;
+
+    auto const &src = in(_, _, _);
+    auto const &wgt = w(_, _, _);
+    auto const &dst = out(_, _, _);
+
+    SrcType const *pSrc = src.data();
+    SrcType const *pWgt = wgt.data();
+    DstType *pDst = dst.data();
+
+    constexpr int num_elements = decltype(size(src))::value;
+    constexpr int vec_size = 8;
+
+    auto item = syclcompat::get_nd_item<1>();
+    int loc_id = item.get_local_linear_id();
+    int group_id = item.get_group_linear_id();
+    int work_group_size = item.get_local_range(0);
+
+    // Per-work-item sum of squares, accumulated in float8 lanes and then reduced across lanes.
+    cute::intel::float8 x2_vec = 0.0f;
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < num_elements / vec_size; ++i) {
+      cute::intel::float8 src_vec = *reinterpret_cast<cute::intel::float8 const *>(pSrc + i * vec_size);
+      x2_vec += src_vec * src_vec;
+    }
+    float x2_sum = 0.0f;
+    CUTLASS_PRAGMA_UNROLL
+    for (int j = 0; j < vec_size; ++j) {
+      x2_sum += x2_vec[j];
+    }
+
+    // Workgroup tree reduction of the partial sums in shared local memory.
+    scratch[loc_id] = x2_sum;
+    for (int i = work_group_size / 2; i > 0; i >>= 1) {
+      item.barrier(sycl::access::fence_space::local_space);
+      if (loc_id < i) {
+        scratch[loc_id] += scratch[loc_id + i];
+      }
+    }
+    if (loc_id == 0) {
+      accum_acc[group_id] = sycl::rsqrt(scratch[0] / num_elements + eps);
+    }
+    // Barrier before the broadcast read so every work-item sees the reduced value.
+    item.barrier(sycl::access::fence_space::local_space);
+    float rms = accum_acc[group_id];
+
+    // Apply 1/rms and the per-column weight.
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < num_elements / vec_size; ++i) {
+      cute::intel::float8 src_vec = *reinterpret_cast<cute::intel::float8 const *>(pSrc + i * vec_size);
+      cute::intel::float8 wgt_vec = *reinterpret_cast<cute::intel::float8 const *>(pWgt + i * vec_size);
+      *reinterpret_cast<cute::intel::float8 *>(pDst + i * vec_size) = wgt_vec * src_vec * rms;
+    }
+  }
+
+  // Perform a subgroup-scoped matrix multiply-accumulate
+  template <class FrgTensorD, class TensorA, class TensorB, class FrgTensorC, class KTileIterator, class BlkCoord>
+  CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, FrgTensorC const &src_accum,
+      KTileIterator k_tile_iter, int k_tile_count, BlkCoord const &blk_coord, int const &K_start, int thread_idx,
+      Params const &mainloop) {
+    (void)blk_coord;
+    static_assert(is_rmem<FrgTensorD>::value, "D tensor must be rmem resident.");
+    static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
+
+    auto thr_copy_A = mainloop.tiled_copy_a.get_slice(thread_idx);
+    auto thr_copy_B = mainloop.tiled_copy_b.get_slice(thread_idx);
+
+    // Instantiate the MMA object and get thread slice
+    TiledMma tiled_mma;
+    auto sg = syclcompat::get_nd_item<1>().get_sub_group();
+    auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize;
+    auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
+
+    // Partition
+    Tensor tCgA = thr_mma.partition_A(gA);
+    Tensor tCgB = thr_mma.partition_B(gB);
+
+    Tensor tCrA = make_tensor<uint8_t>(make_fragment_layout(mainloop.tiled_copy_a, tCgA(_,_,_,0).shape()));
+    Tensor tCrB = make_tensor<uint8_t>(make_fragment_layout(mainloop.tiled_copy_b, tCgB(_,_,_,0).shape()));
+
+    Tensor tCrA_fp16 = make_fragment_like<half_t>(tCrA);
+    Tensor tCrB_fp16 = make_fragment_like<half_t>(tCrB);
+
+    // Retile registers for copies
+    Tensor tArA = thr_copy_A.retile_D(tCrA);
+    Tensor tBrB = thr_copy_B.retile_D(tCrB);
+
+    // Retile global tile for copies
+    Tensor tAgA = thr_copy_A.retile_S(tCgA);
+    Tensor tBgB = thr_copy_B.retile_S(tCgB);
+
+    //
+    // Mainloop
+    //
+    const auto k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start));
+    constexpr int barrier_scope = 2;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++) {
+      barrier_arrive(barrier_scope);
+
+      // copy fp8 into uint8
+      copy(mainloop.tiled_copy_a, tAgA(_,_,_,k_tile), tArA);
+      copy(mainloop.tiled_copy_b, tBgB(_,_,_,k_tile), tBrB);
+
+      // convert the fp8 (uint8) fragments to fp16 so the MMA below reads initialized data
+      convert_E4M3_to_FP16(tCrA, tCrA_fp16);
+      convert_E4M3_to_FP16(tCrB, tCrB_fp16);
+
+      // compute using fp16
+      cute::gemm(tiled_mma, tCrA_fp16, tCrB_fp16, accum);
+
+      barrier_wait(barrier_scope);
+    }
+  }
+
+};
+
+} // namespace cutlass::gemm::collective
+
+/////////////////////////////////////////////////////////////////////////////////////////////////

From f123be4fcb030722c2a9296d56d13a203f18d62a Mon Sep 17 00:00:00 2001
From: Yuankun Shi
Date: Wed, 16 Apr 2025 21:44:30 -0700
Subject: [PATCH 2/6] add rmsnorm post ops

---
 .../05_pvc_gemm_with_epilogue_rmsnorm.cpp         | 529 ++++++++++++++++++
 .../05_pvc_gemm_with_epilogues/CMakeLists.txt     |   7 +
 .../cutlass/epilogue/fusion/operations.hpp        |  14 +
 .../cutlass/epilogue/fusion/xe_callbacks.hpp      |  82 +++
 .../epilogue/fusion/xe_visitor_rmsnorm.hpp        | 408 ++++++++++++++
 .../gemm/collective/xe_mma_rmsnorm.hpp            | 308 ----------
 6 files changed, 1040 insertions(+), 308 deletions(-)
 create mode 100644 examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp
 create mode 100644 include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp
 delete mode 100644 include/cutlass/gemm/collective/xe_mma_rmsnorm.hpp

diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp
new file mode 100644
index 0000000000..f9869dce97
--- /dev/null
+++ b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp
@@ -0,0 +1,529 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3.
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_relu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline +bool isclose(T a, T b, float atol, float rtol) { + return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); +} + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 512); + cmd.get_cmd_line_argument("n", n, 512); + cmd.get_cmd_line_argument("k", k, 64); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 1); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "PVC GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + using LayoutW = cutlass::layout::RowMajor; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + using ElementW = typename Gemm::ElementA; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 10; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + cutlass::DeviceAllocation block_W; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + cutlass::TensorRef ref_W(block_W.get(), LayoutW::packed({1, N})); + // printf("ref_D:"); + // for (int i = 0; i < 5; ++i) { + // printf("%f ", block_A.get()[i]); + // } + // printf("\nstride:"); + // auto stride = ref_D.stride(); + // // auto layout = ref_A.layout(); + // for (int i = 0; i < 3; ++i) { + // printf("%d ", ref_D.stride(i)); + // } + // printf("\n"); + // printf("\nshape"); + // for (int i = 0; i < 2; ++i) { + // printf("%d ", ref_A.layout(i)); + // } + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + syclcompat::wait(); + + ElementOutput *ptr_refD = + (ElementOutput *)std::malloc(M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr_refD, block_ref_D.get(), + M * N * L * sizeof(ElementOutput)); + syclcompat::wait(); + 
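+ // Copy the rmsnorm weights back to the host, then apply the reference row-wise RMSNorm
+ // (y[m][n] = D[m][n] * w[n] / sqrt(mean_n(D[m][:]^2) + eps)) to the reference GEMM output below.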
ElementW *ptr_wgt = + (ElementW *)std::malloc(N * L * sizeof(ElementW)); + syclcompat::memcpy(ptr_wgt, block_W.get(), + N * L * sizeof(ElementW)); + syclcompat::wait(); + // printf("ptr_ref:\n"); + + // for (int m = 0; m < M; ++m) { + // for (int nn = 0; nn < N / 16; ++nn) { + // printf("%04d:(%03d) ", m, nn * 16); + // for (int n = 0; n < 16; ++n) { + // printf("%5.1f ", ptr_refD[m * N + nn * 16 + n]); + // } + // printf("\n"); + // } + // } + + constexpr float eps = 1e-5; + float p2[M * N]; + float p2sum[M]; + for (int l = 0; l < L; l++) { + for (int m = 0; m < M; m++) { + float pow2_sum = (float)0; + for (int n = 0; n < N; n++) { + p2[m * N + n] = pow(ptr_refD[l * M * N + m * N + n], 2); + pow2_sum += p2[m * N + n]; + } + p2sum[m] = pow2_sum; + float rms = 1.0f / sqrt(pow2_sum / (float)N + eps); + + for (int n = 0; n < N; n++) { + ptr_refD[l * M * N + m * N + n] = ptr_refD[l * M * N + m * N + n] * rms * (float)ptr_wgt[n]; + } + } + } + // printf("p2:\n"); + // for (int m = 0; m < M; ++m) { + // for (int nn = 0; nn < N / 16; ++nn) { + // printf("%4d:(%2d) ", m, nn * 16); + // for (int n = 0; n < 16; ++n) { + // printf("%5.1f ", p2[m * N + nn * 16 + n]); + // } + // printf("\n"); + // } + // } + // printf("ptr_wgt:\n"); + // for (int nn = 0; nn < 4; ++nn) { + // printf("%d: ", nn * 16); + // for (int n = 0; n < 16; ++n) { + // printf("%5.1f ", (float)ptr_wgt[nn * 16 + n]); + // } + // printf("\n"); + // } + // printf("p2sum:\n"); + // for (int mm = 0; mm < M / 16; ++mm) { + // for (int m = 0; m < 16; ++m) { + // printf("%5.1f ", p2sum[mm * 16 + m]); + // } + // printf("\n"); + // } + + // printf("\n"); + // syclcompat::memcpy(block_ref_D.get(), ptr, + // M * N * L * sizeof(ElementOutput)); + // syclcompat::wait(); + + ElementOutput *ptr_D = + (ElementOutput *)std::malloc((size_t)M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr_D, block_D.get(), + (size_t)M * N * L * sizeof(ElementOutput)); + syclcompat::wait(); + + uint32_t err_cnt = 0; + + float atol = 1e-3; + float rtol = 1e-3; + for (int b = 0; b < L; b++) { + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + int idx = b * M * N + m * N + n; + auto expect = ptr_refD[idx]; + auto val = ptr_D[idx]; + + if (not (std::isinf(ptr_D[idx]) || std::isnan(ptr_D[idx]))) { + if (not isclose(val, expect, atol, rtol)) { + printf("(%d,%d,%d): host: %f and device: %f ratio: %f\n", b, n, m, expect, val, val / expect); + err_cnt++; + } // else{ + // printf("(%d,%d,%d): host: %f and device: %f\n", b, i, m, expect, val); + // } + } else { + printf("(%d,%d,%d): host: %f and device: %f\n", b, n, m, expect, val); + err_cnt++; + } + } + } + } + + std::free(ptr_refD); + std::free(ptr_D); + std::cout << "err count: " << err_cnt + << ", pass rate: " << 100 - (100 * err_cnt / (M * N * L)) << "%" + << std::endl; + return err_cnt == 0; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N 
* L); + block_W.reset(N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + initialize_block(block_W, seed + 2020); + // auto a_ptr = block_A.get(); + // for (size_t m = 0; m < M; ++m) { + // for (size_t k = 0; k < K; ++k) { + // a_ptr[m * K + k] = (bfloat16_t)(float)(m * 1000 + k); + // } + // } + // auto b_ptr = block_B.get(); + // for (size_t k = 0; k < K; ++k) { + // for (size_t n = 0; n < N; ++n) { + // if (k == n) + // b_ptr[k * N + n] = (bfloat16_t)1.0f; + // else + // b_ptr[k * N + n] = (bfloat16_t)0.0f; + // } + // } + // printf("initialize done\n"); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + using EpilogueArguments = typename Gemm::GemmKernel::EpilogueArguments; + EpilogueArguments epilogue_arguments{ + {options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}; + epilogue_arguments.thread.output_ptr = block_D.get(); + epilogue_arguments.thread.weight_ptr = block_W.get(); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + epilogue_arguments, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + gemm_op.can_implement(arguments); + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + if (!passed) return cutlass::Status::kErrorInternal; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + double io = + options.l * + (options.m * options.k * sizeof(ElementA) + options.k * options.n * sizeof(ElementB) + + options.m * options.n * sizeof(ElementOutput)) * + 1e-9; + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]GB/s, [%4.3f]TF/s, [%6.4f]ms\n", io / cute_time, tflops/cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. 
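+ // A typical invocation of this example (the binary path depends on your build directory;
+ // flags map to the Options parser above):
+ //   ./05_pvc_gemm_with_epilogue_rmsnorm --m=512 --n=512 --k=64 --l=1 --iterations=10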
+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementWeight = bfloat16_t; // <- data type for elements in rmsnorm weight + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x8x16_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_32, _512, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_16, _1, _0>>>::TiledMMA; + + using EpilogueTile = Shape<_16, _32>; + constexpr int PipelineStages = 3; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using EpilogueOp = cutlass::epilogue::fusion::LinCombRMSNormRow; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + void, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt b/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt index 7ec1f66411..a5e3c95247 100644 --- a/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt +++ b/examples/sycl/05_pvc_gemm_with_epilogues/CMakeLists.txt @@ -67,3 +67,10 @@ cutlass_example_add_executable( TEST_COMMAND_OPTIONS TEST_BATCHES ) + +cutlass_example_add_executable( + 05_pvc_gemm_with_epilogue_rmsnorm + 05_pvc_gemm_with_epilogue_rmsnorm.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES +) diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index f37eb5b00a..9cd2434575 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -159,6 +159,20 @@ struct LinCombSoftmaxRow : LinearCombination { }; +// D = rmsnorm((alpha * acc + beta * C)) +template< + class ElementWeight_, + class ElementOutput_, + class ElementCompute_, + class CopyOpR2G_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = 
FloatRoundStyle::round_to_nearest +> +struct LinCombRMSNormRow + : LinearCombination { +}; + // D = alpha * acc + beta * C + per-row bias template< class ElementOutput_, diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 0d199fd383..4447189200 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -48,6 +48,7 @@ #include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" +#include "cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -240,6 +241,87 @@ struct FusionCallbacks< using Impl::Impl; }; +// D = rmsnorm(alpha * acc + beta * C) +template< + // int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementWeight, + class ElementOutput, + class ElementCompute, + class CopyOpR2G, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using XeLinCombRMSNormRow = + Sm90EVT, // rmsnorm(beta * C + (alpha * acc)) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + // int FragmentSize, + class ElementWeight_, + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + class CopyOpR2G_, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::IntelPVCEpilogue, + fusion::LinCombRMSNormRow, + CtaTileShapeMNK, + EpilogueTile +> : XeLinCombRMSNormRow { + + using ElementWeight = ElementWeight_; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Impl = XeLinCombRMSNormRow::type, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, CopyOpR2G_, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombRMSNormRow; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementOutput* output_ptr = nullptr; + + using StrideWeight = Stride<_1, _0, int64_t>; + ElementWeight const* weight_ptr = nullptr; + StrideWeight dWeight = {}; + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {output_ptr, weight_ptr} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + template< class StrideAux, class CopyOpG2R, diff --git a/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp new file mode 100644 index 0000000000..27e850f26a --- /dev/null +++ b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp @@ -0,0 +1,408 @@ +/*************************************************************************************************** + * 
Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree RMSNorm fusion operation for the Intel PVC epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include +#include "xe_visitor_softmax.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileShapeMNK, + class EpilogueTile, + class ElementWeight, + class ElementOutput, + class ElementCompute, + class CopyOpR2G, + FloatRoundStyle RoundStyle +> +struct XeRMSNormRowReduction +{ +public: + static constexpr int FragmentSize = 8; + static constexpr auto Tile_M = get<0>(CtaTileShapeMNK{}); + static constexpr auto Tile_N = get<1>(CtaTileShapeMNK{}); + static constexpr auto Epi_M = get<0>(EpilogueTile{}); + static constexpr auto Epi_N = get<1>(EpilogueTile{}); + static constexpr auto Sg_M = Tile_M / Epi_M; + static constexpr auto Sg_N = Tile_N / Epi_N; + static constexpr auto Sg_Nums = Sg_M * Sg_N; + + using Trait_Output = Copy_Traits; + using XE_Copy_output = decltype(make_tiled_copy(Copy_Atom{} + .with(static_cast(nullptr),int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename Trait_Output::BlockShape{}), + get<1>(typename Trait_Output::BlockShape{}) / Int{})))); + + struct SharedStorage { }; + + struct Arguments { + ElementOutput* ptr_output; + ElementWeight const*ptr_weight; + // StrideOutput dOutput; + }; + + struct Params { + XE_Copy_output xe_store_output; + ElementWeight const *weight; + int inner_dim; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + auto problem_shape_MNKL = append<4>(problem_shape, 
1); + auto [M, N, K, L] = problem_shape_MNKL; + XE_Copy_output output = make_tiled_copy(Copy_Atom, ElementOutput>{}.with( + args.ptr_output, M, N), + Layout>>{}, + make_layout(make_shape(get<0>(typename XE_Copy_output::BlockShape{}), + get<1>(typename XE_Copy_output::BlockShape{}) / Int{}))); + return {output, args.ptr_weight, N}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + // Cross CTA reduction is not possible because there is no guarantee that all CTAs run + // concurrently. + // Cross epilogue tile reduction is possible, but re-visiting and applying reduction + // to accumulators is only possible for the current epilogue tile. + auto [epi_M, epi_N] = EpilogueTile{}; + return N <= tile_N; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + XeRMSNormRowReduction() { } + + CUTLASS_HOST_DEVICE + XeRMSNormRowReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + CUTLASS_DEVICE static void + print_tensor(VTensor &t) { + print(t); + auto t_shape = t.shape(); + auto t_stride = t.stride(); + auto t_rank = rank(t_shape); + auto total = t.size(); + for (auto i = 0; i < total; ++i) { + if ((i % get<0>(t_shape)) == 0) + print("\n%2d: ", i / get<0>(t_shape)); + print("%5.1f ", t[i]); + } + print("\n"); + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& res_tensor, CoordTensor&& coord, Params const& params) + : res_tensor(cute::forward(res_tensor)), + coord(cute::forward(coord)), + params(params) {} + + RTensor res_tensor; + CoordTensor coord; + Params const& params; + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + return frg_acc; + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + constexpr float eps = 1e-5; + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + auto group = syclcompat::get_nd_item<1>().get_group()[0]; + auto group_id = group; + auto sg_group_id = sg.get_group_id(); + auto sg_local_id = sg.get_local_id()[0]; + if(is_last_iteration) { + for(int epi_v = 0; epi_v < visit_results(0).size(); epi_v++) { + res_tensor(epi_v, epi_m, epi_n) = visit_results(0)[epi_v]; + } + // if (cute::thread0()) { + // print("xe_store_output "); + // print(params.xe_store_output); + // print("\n"); + // print("\n"); + // print_tensor(res_tensor); + // for (int i = 0; i < 2; ++i) { + // print("%5.1f ", res_tensor(0, 0, i)); + // } + // print("\n"); + // print("Epi_M "); + // print(Epi_M); + // print("Sg_M "); + // print(Sg_M); + // print("Epi_N "); + // print(Epi_N); + // print("Sg_N "); + // 
print(Sg_N); + // print("epi_m "); + // print(epi_m); + // print(" epi_n "); + // print(epi_n); + // print("\n"); + // } + + constexpr auto vec_size = min(Epi_M, Sg_N); + constexpr auto vec_folds = Epi_M / vec_size; + auto smem = syclcompat::local_mem(); + Tensor stensor = make_tensor(make_smem_ptr(smem), make_shape(Int{}, Int{}, Int{})); + auto wgt_ptr=params.weight; + // Tensor weight = make_tensor(params.weight, make_shape(Int<16>{})); // add bias offset here + Tensor res = + make_tensor(static_cast(res_tensor).data(), + make_shape(Int{}, Int{}, Int{})); + // int ts = 256; + // int te = ts + 4; + // int bid = 0; + // for (int t = ts; t < te; ++t) { + // if (cute::thread(t, bid)) { + // printf("t%d: ", t); + // print_tensor(res); + // print("\n"); + // } + // sync_fn(); + // } + // print("vec_size "); + // print(vec_size); + // print("\nvec_folds "); + // print(vec_folds); + // print("\nstensor:"); + // print(stensor); + // print("\nres:"); + // print(res); + // print("\n"); + // } + // CUTLASS_PRAGMA_UNROLL + // for (int loop = 0; loop < vec_folds; loop++) { + // auto loop_t = res(_, loop, _); + // Tensor group_max = make_tensor(make_shape(Int{})); + // group_reduce_max(stensor, loop_t, group_max); + // CUTLASS_PRAGMA_UNROLL + // for (int i = 0; i < Epi_N / IntelPVCEpilogue::SubgroupSize; i++) { + // auto element_vec = loop_t(_, i); + // CUTLASS_PRAGMA_UNROLL + // for (int j = 0; j < vec_size; j++) { + // element_vec(j) -= group_max(j); + // } + // } + // } + // square + Tensor pow2_buff = make_tensor_like(res); + CUTLASS_PRAGMA_UNROLL + for (int loop = 0; loop < vec_folds; loop++) { + auto loop_t = res(_, loop, _); + auto pow2_t = pow2_buff(_, loop, _); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Epi_N / IntelPVCEpilogue::SubgroupSize; i++) { + auto x_vec = loop_t(_, i); + auto p2_vec = pow2_t(_, i); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < vec_size; j++) { + p2_vec(j) = x_vec(j) * x_vec(j); + } + } + } + // if (cute::thread0()) { + // print("N: "); + // print(params.inner_dim); + // print("\n"); + // print("pow2: "); + // print(pow2_buff); + // print_tensor(pow2_buff); + // print("\n"); + // } + + // auto gid = syclcompat::get_nd_item<1>().get_global_linear_id(); + // if (cute::thread0()) { + // print("Epi_N "); + // print(Epi_N); + // print("SubgroupSize"); + // print(IntelPVCEpilogue::SubgroupSize); + // } + int gx = syclcompat::global_id::x() % 256; + int gy = syclcompat::global_id::y(); + auto gid = gx / 16 * 32 + gx % 16; // + syclcompat::local_id::y() * syclcompat::local_range::x(); + // const float wgt_per_col = (float)wgt_ptr[gid + IntelPVCEpilogue::SubgroupSize] + // for (int t = ts; t < te; ++t) { + // sync_fn(); + // if (cute::thread(t, bid)) { + // print("gid "); + // print(syclcompat::get_nd_item<1>().get_global_linear_id()); + // print(" z: "); + // print(syclcompat::local_id::z()); + // print(" y: "); + // print(syclcompat::local_id::y()); + // print(" x: "); + // print(syclcompat::local_id::x()); + // print(" wz: "); + // print(syclcompat::work_group_id::z()); + // print(" wy: "); + // print(syclcompat::work_group_id::y()); + // print(" wx: "); + // print(syclcompat::work_group_id::x()); + // print(" gz: "); + // print(syclcompat::global_id::z()); + // print(" gy: "); + // print(syclcompat::global_id::y()); + // print(" gx: "); + // print(syclcompat::global_id::x()); + // print(" "); + // for (int i = 0; i < 4; ++i) { + // printf("%f ", (float)wgt_ptr[t / 16 * 32 + t % 16 + i]); + // } + // printf("\n"); + // } + // sync_fn(); + // } + CUTLASS_PRAGMA_UNROLL + 
for (int loop = 0; loop < vec_folds; loop++) {
+        auto loop_t = res(_, loop, _);
+        auto pow2_t = pow2_buff(_, loop, _);
+        Tensor group_sum = make_tensor<float>(make_shape(Int<vec_size>{}));
+        float rev_dim = 1 / (float)params.inner_dim;
+        group_reduce_sum(stensor, pow2_t, group_sum);
+        // if (cute::thread0()) {
+        //   print("group_sum: ");
+        //   print_tensor(group_sum);
+        //   print("\n");
+        // }
+        Tensor rms = make_tensor<float>(make_shape(Int<vec_size>{}));
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < vec_size; ++i) {
+          rms(i) = pow(group_sum(i) * rev_dim + eps, -0.5);
+        }
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < Epi_N / IntelPVCEpilogue::SubgroupSize; i++) {
+          const float wgt_per_col = (float)wgt_ptr[gid + i * IntelPVCEpilogue::SubgroupSize];
+          // const float wgt_per_col = 1.0f;
+          auto rmsnorm_vec = loop_t(_, i);
+          CUTLASS_PRAGMA_UNROLL
+          for (int j = 0; j < vec_size; j++) {
+            rmsnorm_vec(j) = rmsnorm_vec(j) * rms(j) * wgt_per_col;
+          }
+        }
+      }
+
+      copy(params.xe_store_output, res_tensor, coord);
+    }
+    else {
+      for(int epi_v = 0; epi_v < visit_results(0).size(); epi_v++) {
+        res_tensor(epi_v, epi_m, epi_n) = visit_results(0)[epi_v];
+      }
+    }
+  }
+  };
+
+  template <
+    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
+    class... Args
+  >
+  CUTLASS_DEVICE auto
+  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+    using MmaAtomShape = typename decltype(args.tiled_mma)::AtomShape_MNK;
+    static constexpr int FragsM = get<0>(EpilogueTile{}) / get<0>(MmaAtomShape()); // A frags per sub_group
+    static constexpr int FragsN = get<1>(EpilogueTile{}) / get<1>(MmaAtomShape()); // B frags per sub_group
+    Tensor res = make_tensor<ElementCompute>(Shape<Int<FragmentSize>, Int<FragsM>, Int<FragsN>>{});
+
+    auto [sg_m_coord, sg_n_coord, k_coord, l_offset] = args.tile_coord_mnkl;
+    auto [M, N, K, L] = args.problem_shape_mnkl;
+    Tensor mAux_mnl = cute::get_pvc_tensor(make_shape(M,N,L));
+    // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel
+    Tensor gAux = local_tile(mAux_mnl, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord,l_offset));
+    Tensor tCgAux = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux);
+    return ConsumerStoreCallbacks<decltype(res), decltype(tCgAux)>(
+      cute::move(res),
+      cute::move(tCgAux),
+      params);
+  }
+
+};
+
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::epilogue::fusion
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/include/cutlass/gemm/collective/xe_mma_rmsnorm.hpp b/include/cutlass/gemm/collective/xe_mma_rmsnorm.hpp
deleted file mode 100644
index 5c663ec10a..0000000000
--- a/include/cutlass/gemm/collective/xe_mma_rmsnorm.hpp
+++ /dev/null
@@ -1,308 +0,0 @@
[308 deleted lines omitted -- the body duplicates the xe_mma_rmsnorm.hpp added in PATCH 1/6 above.]
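For reference while reading the visitor above: a minimal host-side sketch of the same row-wise
RMSNorm the epilogue computes. The function name and the standalone setting are illustrative,
not part of the patch; it mirrors the verify() reference in the example.

#include <cmath>
#include <vector>

// Normalize each length-N row of a row-major M x N matrix in place:
//   d[m][n] *= w[n] / sqrt(mean_n(d[m][:]^2) + eps)
void rmsnorm_rows(std::vector<float>& d, const std::vector<float>& w,
                  int M, int N, float eps = 1e-5f) {
  for (int m = 0; m < M; ++m) {
    float sum_sq = 0.f;
    for (int n = 0; n < N; ++n) {
      float x = d[m * N + n];
      sum_sq += x * x;                   // sum of squares over the row
    }
    float inv_rms = 1.f / std::sqrt(sum_sq / N + eps);
    for (int n = 0; n < N; ++n) {
      d[m * N + n] *= inv_rms * w[n];    // scale by 1/rms and the per-column weight
    }
  }
}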
From a185040e63bc194602a62b159caee1cc97d9854a Mon Sep 17 00:00:00 2001
From: Yuankun Shi
Date: Mon, 21 Apr 2025 17:55:10 -0700
Subject: [PATCH 3/6] remove comment

---
 .../05_pvc_gemm_with_epilogue_rmsnorm.cpp          |  87 +----------
 include/cutlass/epilogue/fusion/xe_callbacks.hpp   |   3 +-
 .../epilogue/fusion/xe_visitor_rmsnorm.hpp         | 139 +-----------------
 3 files changed, 13 insertions(+), 216 deletions(-)

diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp
diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp
index f9869dce97..a0342adae0 100644
--- a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp
+++ b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp
@@ -68,7 +68,7 @@ struct Options {
   bool error;
 
   int m, n, k, l, iterations;
-  float alpha, beta;
+  float alpha, beta, eps;
 
   Options():
     help(false),
@@ -92,6 +92,7 @@ struct Options {
     cmd.get_cmd_line_argument("l", l, 1);
     cmd.get_cmd_line_argument("alpha", alpha, 1.f);
    cmd.get_cmd_line_argument("beta", beta, 0.f);
+    cmd.get_cmd_line_argument("eps", eps, 1e-5f);
     cmd.get_cmd_line_argument("iterations", iterations, 1);
   }
 
@@ -174,21 +175,6 @@ struct ExampleRunner {
     cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N}));
     cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N}));
     cutlass::TensorRef ref_W(block_W.get(), LayoutW::packed({1, N}));
-    // printf("ref_D:");
-    // for (int i = 0; i < 5; ++i) {
-    //   printf("%f ", block_A.get()[i]);
-    // }
-    // printf("\nstride:");
-    // auto stride = ref_D.stride();
-    // // auto layout = ref_A.layout();
-    // for (int i = 0; i < 3; ++i) {
-    //   printf("%d ", ref_D.stride(i));
-    // }
-    // printf("\n");
-    // printf("\nshape");
-    // for (int i = 0; i < 2; ++i) {
-    //   printf("%d ", ref_A.layout(i));
-    // }
 
     cutlass::reference::device::GemmComplex(
       {M, N, K},
       alpha,
@@ -219,29 +205,15 @@ struct ExampleRunner {
     syclcompat::memcpy(ptr_wgt, block_W.get(), N * L * sizeof(ElementW));
     syclcompat::wait();
 
-    // printf("ptr_ref:\n");
-    // for (int m = 0; m < M; ++m) {
-    //   for (int nn = 0; nn < N / 16; ++nn) {
-    //     printf("%04d:(%03d) ", m, nn * 16);
-    //     for (int n = 0; n < 16; ++n) {
-    //       printf("%5.1f ", ptr_refD[m * N + nn * 16 + n]);
-    //     }
-    //     printf("\n");
-    //   }
-    // }
     constexpr float eps = 1e-5;
-    float p2[M * N];
-    float p2sum[M];
+    // rowwise rmsnorm
     for (int l = 0; l < L; l++) {
       for (int m = 0; m < M; m++) {
         float pow2_sum = (float)0;
         for (int n = 0; n < N; n++) {
-          p2[m * N + n] = pow(ptr_refD[l * M * N + m * N + n], 2);
-          pow2_sum += p2[m * N + n];
+          pow2_sum += pow(ptr_refD[l * M * N + m * N + n], 2);
         }
-        p2sum[m] = pow2_sum;
 
         float rms = 1.0f / sqrt(pow2_sum / (float)N + eps);
         for (int n = 0; n < N; n++) {
@@ -249,36 +221,6 @@ struct ExampleRunner {
         }
       }
     }
-    // printf("p2:\n");
-    // for (int m = 0; m < M; ++m) {
-    //   for (int nn = 0; nn < N / 16; ++nn) {
-    //     printf("%4d:(%2d) ", m, nn * 16);
-    //     for (int n = 0; n < 16; ++n) {
-    //       printf("%5.1f ", p2[m * N + nn * 16 + n]);
-    //     }
-    //     printf("\n");
-    //   }
-    // }
-    // printf("ptr_wgt:\n");
-    // for (int nn = 0; nn < 4; ++nn) {
-    //   printf("%d: ", nn * 16);
-    //   for (int n = 0; n < 16; ++n) {
-    //     printf("%5.1f ", (float)ptr_wgt[nn * 16 + n]);
-    //   }
-    //   printf("\n");
-    // }
-    // printf("p2sum:\n");
-    // for (int mm = 0; mm < M / 16; ++mm) {
-    //   for (int m = 0; m < 16; ++m) {
-    //     printf("%5.1f ", p2sum[mm * 16 + m]);
-    //   }
-    //   printf("\n");
-    // }
-
-    // printf("\n");
-    // syclcompat::memcpy(block_ref_D.get(), ptr,
-    //                    M * N * L * sizeof(ElementOutput));
-    // syclcompat::wait();
 
     ElementOutput *ptr_D = (ElementOutput *)std::malloc((size_t)M * N * L * sizeof(ElementOutput));
@@ -301,9 +243,7 @@ struct ExampleRunner {
       if (not isclose(val, expect, atol, rtol)) {
         printf("(%d,%d,%d): host: %f and device: %f ratio: %f\n", b, n, m, expect, val, val / expect);
         err_cnt++;
-      } // else{
-      //   printf("(%d,%d,%d): host: %f and device: %f\n", b, i, m, expect, val);
-      // }
+      }
     } else {
       printf("(%d,%d,%d): host: %f and device: %f\n", b, n, m, expect, val);
       err_cnt++;
@@ -341,22 +281,6 @@ struct ExampleRunner {
     initialize_block(block_B, seed + 2022);
     initialize_block(block_C, seed + 2021);
     initialize_block(block_W, seed + 2020);
-    // auto a_ptr = block_A.get();
-    // for (size_t m = 0; m < M; ++m) {
-    //   for (size_t k = 0; k < K; ++k) {
-    //     a_ptr[m * K + k] = (bfloat16_t)(float)(m * 1000 + k);
-    //   }
-    // }
-    // auto b_ptr = block_B.get();
-    // for (size_t k = 0; k < K; ++k) {
-    //   for (size_t n = 0; n < N; ++n) {
-    //     if (k == n)
-    //       b_ptr[k * N + n] = (bfloat16_t)1.0f;
-    //     else
-    //       b_ptr[k * N + n] = (bfloat16_t)0.0f;
-    //   }
-    // }
-    // printf("initialize done\n");
   }
 
   cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
@@ -368,6 +292,7 @@ struct ExampleRunner {
       {options.alpha, options.beta},
       block_C.get(), stride_C, block_D.get(), stride_D};
     epilogue_arguments.thread.output_ptr = block_D.get();
     epilogue_arguments.thread.weight_ptr = block_W.get();
+    epilogue_arguments.thread.eps = options.eps;
 
     typename Gemm::GemmKernel::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
device: %f\n", b, n, m, expect, val); err_cnt++; @@ -341,22 +281,6 @@ struct ExampleRunner { initialize_block(block_B, seed + 2022); initialize_block(block_C, seed + 2021); initialize_block(block_W, seed + 2020); - // auto a_ptr = block_A.get(); - // for (size_t m = 0; m < M; ++m) { - // for (size_t k = 0; k < K; ++k) { - // a_ptr[m * K + k] = (bfloat16_t)(float)(m * 1000 + k); - // } - // } - // auto b_ptr = block_B.get(); - // for (size_t k = 0; k < K; ++k) { - // for (size_t n = 0; n < N; ++n) { - // if (k == n) - // b_ptr[k * N + n] = (bfloat16_t)1.0f; - // else - // b_ptr[k * N + n] = (bfloat16_t)0.0f; - // } - // } - // printf("initialize done\n"); } cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { @@ -368,6 +292,7 @@ struct ExampleRunner { {options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}; epilogue_arguments.thread.output_ptr = block_D.get(); epilogue_arguments.thread.weight_ptr = block_W.get(); + epilogue_arguments.thread.eps = options.eps; typename Gemm::GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 4447189200..a93cb1bfe4 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -298,6 +298,7 @@ struct FusionCallbacks< using StrideWeight = Stride<_1, _0, int64_t>; ElementWeight const* weight_ptr = nullptr; + float eps = 1e-5; StrideWeight dWeight = {}; operator typename Impl::Arguments() const { @@ -313,7 +314,7 @@ struct FusionCallbacks< }, // end binary op {} // ternary args : multiply_add }, // end ternary op - {output_ptr, weight_ptr} // unary args: activation + {output_ptr, weight_ptr, eps} // unary args: activation }; // end unary op } }; diff --git a/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp index 27e850f26a..e3ef28bd69 100644 --- a/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp +++ b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp @@ -77,12 +77,14 @@ struct XeRMSNormRowReduction struct Arguments { ElementOutput* ptr_output; ElementWeight const*ptr_weight; + const float eps; // StrideOutput dOutput; }; struct Params { XE_Copy_output xe_store_output; ElementWeight const *weight; + float eps; int inner_dim; }; @@ -96,7 +98,7 @@ struct XeRMSNormRowReduction Layout>>{}, make_layout(make_shape(get<0>(typename XE_Copy_output::BlockShape{}), get<1>(typename XE_Copy_output::BlockShape{}) / Int{}))); - return {output, args.ptr_weight, N}; + return {output, args.ptr_weight, args.eps, N}; } template @@ -150,22 +152,6 @@ struct XeRMSNormRowReduction return EmptyProducerLoadCallbacks{}; } - template - CUTLASS_DEVICE static void - print_tensor(VTensor &t) { - print(t); - auto t_shape = t.shape(); - auto t_stride = t.stride(); - auto t_rank = rank(t_shape); - auto total = t.size(); - for (auto i = 0; i < total; ++i) { - if ((i % get<0>(t_shape)) == 0) - print("\n%2d: ", i / get<0>(t_shape)); - print("%5.1f ", t[i]); - } - print("\n"); - } - template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { @@ -189,7 +175,7 @@ struct XeRMSNormRowReduction template CUTLASS_DEVICE void reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - constexpr float eps = 1e-5; + const float eps = params.eps; auto sg = syclcompat::get_nd_item<1>().get_sub_group(); auto 
@@ -189,7 +175,7 @@ struct XeRMSNormRowReduction
     template <class STensor, class SyncFn, class VTensor>
     CUTLASS_DEVICE void
     reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n,
            bool is_last_iteration, VTensor visit_results) {
-      constexpr float eps = 1e-5;
+      const float eps = params.eps;
       auto sg = syclcompat::get_nd_item<1>().get_sub_group();
       auto group = syclcompat::get_nd_item<1>().get_group()[0];
       auto group_id = group;
@@ -199,75 +185,15 @@ struct XeRMSNormRowReduction
     for(int epi_v = 0; epi_v < visit_results(0).size(); epi_v++) {
       res_tensor(epi_v, epi_m, epi_n) = visit_results(0)[epi_v];
     }
-    // if (cute::thread0()) {
-    //   print("xe_store_output ");
-    //   print(params.xe_store_output);
-    //   print("\n");
-    //   print("\n");
-    //   print_tensor(res_tensor);
-    //   for (int i = 0; i < 2; ++i) {
-    //     print("%5.1f ", res_tensor(0, 0, i));
-    //   }
-    //   print("\n");
-    //   print("Epi_M ");
-    //   print(Epi_M);
-    //   print("Sg_M ");
-    //   print(Sg_M);
-    //   print("Epi_N ");
-    //   print(Epi_N);
-    //   print("Sg_N ");
-    //   print(Sg_N);
-    //   print("epi_m ");
-    //   print(epi_m);
-    //   print(" epi_n ");
-    //   print(epi_n);
-    //   print("\n");
-    // }
     constexpr auto vec_size = min(Epi_M, Sg_N);
     constexpr auto vec_folds = Epi_M / vec_size;
 
     auto smem = syclcompat::local_mem();
     Tensor stensor = make_tensor(make_smem_ptr(smem), make_shape(Int{}, Int{}, Int{}));
     auto wgt_ptr=params.weight;
-    // Tensor weight = make_tensor(params.weight, make_shape(Int<16>{}));
     // add bias offset here
     Tensor res = make_tensor(static_cast(res_tensor).data(),
                              make_shape(Int<vec_size>{}, Int<vec_folds>{}, Int<Epi_N / IntelPVCEpilogue::SubgroupSize>{}));
-    // int ts = 256;
-    // int te = ts + 4;
-    // int bid = 0;
-    // for (int t = ts; t < te; ++t) {
-    //   if (cute::thread(t, bid)) {
-    //     printf("t%d: ", t);
-    //     print_tensor(res);
-    //     print("\n");
-    //   }
-    //   sync_fn();
-    // }
-    // print("vec_size ");
-    // print(vec_size);
-    // print("\nvec_folds ");
-    // print(vec_folds);
-    // print("\nstensor:");
-    // print(stensor);
-    // print("\nres:");
-    // print(res);
-    // print("\n");
-    // }
-    // CUTLASS_PRAGMA_UNROLL
-    // for (int loop = 0; loop < vec_folds; loop++) {
-    //   auto loop_t = res(_, loop, _);
-    //   Tensor group_max = make_tensor(make_shape(Int{}));
-    //   group_reduce_max(stensor, loop_t, group_max);
-    //   CUTLASS_PRAGMA_UNROLL
-    //   for (int i = 0; i < Epi_N / IntelPVCEpilogue::SubgroupSize; i++) {
-    //     auto element_vec = loop_t(_, i);
-    //     CUTLASS_PRAGMA_UNROLL
-    //     for (int j = 0; j < vec_size; j++) {
-    //       element_vec(j) -= group_max(j);
-    //     }
-    //   }
-    // }
     // square
     Tensor pow2_buff = make_tensor_like(res);
     CUTLASS_PRAGMA_UNROLL
@@ -284,58 +210,9 @@ struct XeRMSNormRowReduction
         }
       }
     }
-    // if (cute::thread0()) {
-    //   print("N: ");
-    //   print(params.inner_dim);
-    //   print("\n");
-    //   print("pow2: ");
-    //   print(pow2_buff);
-    //   print_tensor(pow2_buff);
-    //   print("\n");
-    // }
-
-    // auto gid = syclcompat::get_nd_item<1>().get_global_linear_id();
-    // if (cute::thread0()) {
-    //   print("Epi_N ");
-    //   print(Epi_N);
-    //   print("SubgroupSize");
-    //   print(IntelPVCEpilogue::SubgroupSize);
-    // }
     int gx = syclcompat::global_id::x() % 256;
     int gy = syclcompat::global_id::y();
-    auto gid = gx / 16 * 32 + gx % 16; // + syclcompat::local_id::y() * syclcompat::local_range::x();
-    // const float wgt_per_col = (float)wgt_ptr[gid + IntelPVCEpilogue::SubgroupSize]
-    // for (int t = ts; t < te; ++t) {
-    //   sync_fn();
-    //   if (cute::thread(t, bid)) {
-    //     print("gid ");
-    //     print(syclcompat::get_nd_item<1>().get_global_linear_id());
-    //     print(" z: ");
-    //     print(syclcompat::local_id::z());
-    //     print(" y: ");
-    //     print(syclcompat::local_id::y());
-    //     print(" x: ");
-    //     print(syclcompat::local_id::x());
-    //     print(" wz: ");
-    //     print(syclcompat::work_group_id::z());
-    //     print(" wy: ");
-    //     print(syclcompat::work_group_id::y());
-    //     print(" wx: ");
-    //     print(syclcompat::work_group_id::x());
-    //     print(" gz: ");
-    //     print(syclcompat::global_id::z());
-    //     print(" gy: ");
-    //     print(syclcompat::global_id::y());
-    //     print(" gx: ");
-    //     print(syclcompat::global_id::x());
-    //     print(" ");
-    //     for (int i = 0; i < 4; ++i) {
-    //       printf("%f ", (float)wgt_ptr[t / 16 * 32 + t % 16 + i]);
-    //     }
-    //     printf("\n");
-    //   }
-    //   sync_fn();
-    // }
+    auto gid = gx / 16 * 32 + gx % 16;
     CUTLASS_PRAGMA_UNROLL
     for (int loop = 0; loop < vec_folds; loop++) {
       auto loop_t = res(_, loop, _);
print(" gy: "); - // print(syclcompat::global_id::y()); - // print(" gx: "); - // print(syclcompat::global_id::x()); - // print(" "); - // for (int i = 0; i < 4; ++i) { - // printf("%f ", (float)wgt_ptr[t / 16 * 32 + t % 16 + i]); - // } - // printf("\n"); - // } - // sync_fn(); - // } + auto gid = gx / 16 * 32 + gx % 16; CUTLASS_PRAGMA_UNROLL for (int loop = 0; loop < vec_folds; loop++) { auto loop_t = res(_, loop, _); @@ -343,11 +220,6 @@ struct XeRMSNormRowReduction Tensor group_sum = make_tensor(make_shape(Int{})); float rev_dim = 1 / (float)params.inner_dim; group_reduce_sum(stensor, pow2_t, group_sum); - // if (cute::thread0()) { - // print("group_sum: "); - // print_tensor(group_sum); - // print("\n"); - // } Tensor rms = make_tensor(make_shape(Int{})); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < vec_size; ++i) { @@ -356,7 +228,6 @@ struct XeRMSNormRowReduction CUTLASS_PRAGMA_UNROLL for (int i = 0; i < Epi_N / IntelPVCEpilogue::SubgroupSize; i++) { const float wgt_per_col = (float)wgt_ptr[gid + i * IntelPVCEpilogue::SubgroupSize]; - // const float wgt_per_col = 1.0f; auto rmsnorm_vec = loop_t(_, i); CUTLASS_PRAGMA_UNROLL for (int j = 0; j < vec_size; j++) { From ed681cce0e44afb9a9ef00fc3c993351b36ab631 Mon Sep 17 00:00:00 2001 From: Yuankun Shi Date: Mon, 21 Apr 2025 17:59:33 -0700 Subject: [PATCH 4/6] update version info --- .../05_pvc_gemm_with_epilogue_rmsnorm.cpp | 2 +- include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp index a0342adae0..3594bee9b2 100644 --- a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp +++ b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp index e3ef28bd69..8a78344619 100644 --- a/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp +++ b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. 
From 5d715020b184ecc93523c7a269cf68b48b9a84e1 Mon Sep 17 00:00:00 2001
From: Yuankun Shi
Date: Tue, 20 May 2025 10:28:55 -0700
Subject: [PATCH 5/6] align with latest api change

---
 .../05_pvc_gemm_with_epilogue_rmsnorm.cpp     |  6 +++---
 .../cutlass/epilogue/fusion/xe_callbacks.hpp  |  2 +-
 .../epilogue/fusion/xe_visitor_rmsnorm.hpp    | 21 +++++++++++--------
 3 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp
index 3594bee9b2..5656b4a918 100644
--- a/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp
+++ b/examples/sycl/05_pvc_gemm_with_epilogues/05_pvc_gemm_with_epilogue_rmsnorm.cpp
@@ -328,7 +328,7 @@ struct ExampleRunner {
       gemm_op.run();
     }
     syclcompat::wait();
-    double io = 
+    double io =
       options.l *
       (options.m * options.k * sizeof(ElementA) + options.k * options.n * sizeof(ElementB) +
        options.m * options.n * sizeof(ElementOutput)) *
@@ -404,8 +404,8 @@ int main(int argc, const char** argv)
   using EpilogueTile = Shape<_16, _32>;
   constexpr int PipelineStages = 3;
 
-  using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
-  using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
+  using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
+  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
 
   using EpilogueOp = cutlass::epilogue::fusion::LinCombRMSNormRow
diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp
index 983a14806d..e264a0bcb7 100644
--- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp
+++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp
@@ -272,7 +272,7 @@ template <
   class EpilogueTile
 >
 struct FusionCallbacks<
-    epilogue::IntelPVCEpilogue,
+    epilogue::IntelXeXMX16,
     fusion::LinCombRMSNormRow,
     CtaTileShapeMNK,
     EpilogueTile
diff --git a/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp
index 8a78344619..4616a1ebee 100644
--- a/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp
+++ b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp
@@ -68,9 +68,9 @@ struct XeRMSNormRowReduction
   using Trait_Output = Copy_Traits;
   using XE_Copy_output = decltype(make_tiled_copy(Copy_Atom<Trait_Output, ElementOutput>{}
                                   .with(static_cast<ElementOutput*>(nullptr), int32_t(0), int32_t(0)),
-                                  Layout<Shape<_1, Int<IntelPVCEpilogue::SubgroupSize>>>{},
+                                  Layout<Shape<_1, Int<IntelXeXMX16::SubgroupSize>>>{},
                                   make_layout(make_shape(get<0>(typename Trait_Output::BlockShape{}),
-                                                         get<1>(typename Trait_Output::BlockShape{}) / Int<IntelPVCEpilogue::SubgroupSize>{}))));
+                                                         get<1>(typename Trait_Output::BlockShape{}) / Int<IntelXeXMX16::SubgroupSize>{}))));
 
   struct SharedStorage { };
 
@@ -95,9 +97,9 @@ struct XeRMSNormRowReduction
     auto [M, N, K, L] = problem_shape_MNKL;
     XE_Copy_output output = make_tiled_copy(Copy_Atom<Trait_Output, ElementOutput>{}.with(
                             args.ptr_output, M, N),
-                            Layout<Shape<_1, Int<IntelPVCEpilogue::SubgroupSize>>>{},
+                            Layout<Shape<_1, Int<IntelXeXMX16::SubgroupSize>>>{},
                             make_layout(make_shape(get<0>(typename XE_Copy_output::BlockShape{}),
-                                                   get<1>(typename XE_Copy_output::BlockShape{}) / Int<IntelPVCEpilogue::SubgroupSize>{})));
+                                                   get<1>(typename XE_Copy_output::BlockShape{}) / Int<IntelXeXMX16::SubgroupSize>{})));
     return {output, args.ptr_weight, args.eps, N};
   }
 
@@ -193,7 +193,7 @@ struct XeRMSNormRowReduction
     auto wgt_ptr=params.weight;
     Tensor res = make_tensor(static_cast(res_tensor).data(),
-                             make_shape(Int<vec_size>{}, Int<vec_folds>{}, Int<Epi_N / IntelPVCEpilogue::SubgroupSize>{}));
+                             make_shape(Int<vec_size>{}, Int<vec_folds>{}, Int<Epi_N / IntelXeXMX16::SubgroupSize>{}));
     // square
     Tensor pow2_buff = make_tensor_like(res);
     CUTLASS_PRAGMA_UNROLL
@@ -201,7 +201,7 @@ struct XeRMSNormRowReduction
       auto loop_t = res(_, loop, _);
       auto pow2_t = pow2_buff(_, loop, _);
       CUTLASS_PRAGMA_UNROLL
-      for (int i = 0; i < Epi_N / IntelPVCEpilogue::SubgroupSize; i++) {
+      for (int i = 0; i < Epi_N / IntelXeXMX16::SubgroupSize; i++) {
         auto x_vec = loop_t(_, i);
         auto p2_vec = pow2_t(_, i);
         CUTLASS_PRAGMA_UNROLL
@@ -213,6 +213,9 @@ struct XeRMSNormRowReduction
     int gx = syclcompat::global_id::x() % 256;
     int gy = syclcompat::global_id::y();
     auto gid = gx / 16 * 32 + gx % 16;
+    if (cute::thread0()) {
+      printf("threadx %d blockx %d blockdimx %d\n", ThreadIdxX(), BlockIdxX(), BlockDimX());
+    }
     CUTLASS_PRAGMA_UNROLL
     for (int loop = 0; loop < vec_folds; loop++) {
       auto loop_t = res(_, loop, _);
@@ -226,8 +229,8 @@ struct XeRMSNormRowReduction
         rms(i) = pow(group_sum(i) * rev_dim + eps, -0.5);
       }
       CUTLASS_PRAGMA_UNROLL
-      for (int i = 0; i < Epi_N / IntelPVCEpilogue::SubgroupSize; i++) {
-        const float wgt_per_col = (float)wgt_ptr[gid + i * IntelPVCEpilogue::SubgroupSize];
+      for (int i = 0; i < Epi_N / IntelXeXMX16::SubgroupSize; i++) {
+        const float wgt_per_col = (float)wgt_ptr[gid + i * IntelXeXMX16::SubgroupSize];
         auto rmsnorm_vec = loop_t(_, i);
         CUTLASS_PRAGMA_UNROLL
         for (int j = 0; j < vec_size; j++) {
@@ -259,7 +262,7 @@ struct XeRMSNormRowReduction
     auto [sg_m_coord, sg_n_coord, k_coord, l_offset] = args.tile_coord_mnkl;
     auto [M, N, K, L] = args.problem_shape_mnkl;
 
-    Tensor mAux_mnl = cute::get_pvc_tensor(make_shape(M,N,L));
+    Tensor mAux_mnl = cute::get_xe_tensor(make_shape(M,N,L));
     // Tiling is done differently than in epilogue as we get in coordinates of subgroup in kernel
     Tensor gAux = local_tile(mAux_mnl, select<0,1>(EpilogueTile{}), make_coord(sg_m_coord,sg_n_coord,l_offset));
     Tensor tCgAux = args.tiled_copy.get_thread_slice(args.thread_idx).partition_D(gAux);
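Note: after this rename (IntelPVCEpilogue -> IntelXeXMX16, MainloopIntelPVC ->
MainloopIntelXeXMX16, get_pvc_tensor -> get_xe_tensor), user code selects the Xe
policies by the new names. A typical configuration, following the example file
(the <PipelineStages> argument is assumed from the surrounding context):

    // C++ sketch of the post-rename policy selection.
    constexpr int PipelineStages = 3;
    using GEMMDispatchPolicy     = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

Only the names change; tile shapes, subgroup size (16) and the rest of the kernel
configuration carry over unchanged.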
From 201b448d89399dc2003a9ddb7908922880b1a111 Mon Sep 17 00:00:00 2001
From: Yuankun Shi
Date: Tue, 20 May 2025 10:31:09 -0700
Subject: [PATCH 6/6] remove sycl include

---
 include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp
index 4616a1ebee..fb2c91a791 100644
--- a/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp
+++ b/include/cutlass/epilogue/fusion/xe_visitor_rmsnorm.hpp
@@ -36,7 +36,6 @@
 #pragma once
 
 #include "cutlass/cutlass.h"
-#include <sycl/sycl.hpp>
 #include "xe_visitor_softmax.hpp"
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
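As a closing aid for reviewers, the behaviour the series verifies can be reproduced
host-side with a few lines of standalone C++ — a minimal sketch mirroring the
example's verify loop (illustrative helper, not library code):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Row-wise RMSNorm over an M x N matrix: per row, rms = 1/sqrt(mean(x^2) + eps),
    // then each element is scaled by the per-column weight w[n].
    std::vector<float> rmsnorm_rows(const std::vector<float>& d,
                                    const std::vector<float>& w,
                                    std::size_t m_rows, std::size_t n_cols,
                                    float eps = 1e-5f) {
      std::vector<float> out(d.size());
      for (std::size_t m = 0; m < m_rows; ++m) {
        float sum_sq = 0.f;
        for (std::size_t n = 0; n < n_cols; ++n) {
          float v = d[m * n_cols + n];
          sum_sq += v * v;
        }
        float rms = 1.f / std::sqrt(sum_sq / static_cast<float>(n_cols) + eps);
        for (std::size_t n = 0; n < n_cols; ++n) {
          out[m * n_cols + n] = w[n] * d[m * n_cols + n] * rms;
        }
      }
      return out;
    }

Feeding it the host GEMM reference output and the weight vector should reproduce
block_ref_D within the example's atol/rtol tolerances.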