diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f423fb2da..9ca90d8e9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,17 @@ # NVIDIA CUTLASS Changelog +## [3.9.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.2) (2025-05-03) -## [3.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.0) (2025-03-20) +* Fixed [Blockwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM hang issue when problem size K is 128. +* Optimal code generation with CUDA toolkit versions 12.9. + + +## [3.9.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.1) (2025-04-30) + +* Fixed Group Gemm hang issue in CUTLASS 3.x +* Improved Hopper [Blockwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM performance. + +## [3.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.0) (2025-04-24) * Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API: - Collective mainloops that target for: @@ -13,18 +23,37 @@ - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu). - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu). - [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu). + - [Grouped GEMM with nvfp4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu). + - [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu). + - [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu). * Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM. 
+* Support for Blackwell SM100 Sparse kernels: + - Collective mainloop that targets: + * [SM100 Sparse GEMM](./include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp) +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM: + - [Sparse GEMM](./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu) + - [Blockscaled Sparse GEMM with NVFP4 input data type](./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu) + - [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu) +* Set of unit tests that demonstrate the usage of [sparse](./test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](./test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM. +* A new Multi-head Latent Attention (MLA) kernel for SM100 Blackwell architecture in the CUTLASS [example](./examples/77_blackwell_fmha/) covers the flashMLA-like weight-absorbed decoding use-case. +* A new FMHA Backward kernel for SM100 Blackwell architecture extends the CUTLASS [example](./examples/77_blackwell_fmha/) to show how the five backward pass MMAs can be fused into a single kernel to achieve high performance. +* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture. * Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures: - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. - - Support for [grouped GEMM with blockwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [blockwise and groupwise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in CUTLASS profiler. - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. -* Added support for enhanced kernel performance search in CUTLASS: + - Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture. +* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler: - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels. - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
- Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration. - - More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). + - More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). +* Support `void` as the D element in sm100 kernel epilogues. +* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! +* Optimal code generation with CUDA toolkit versions 12.8U1. ## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25) @@ -40,7 +69,7 @@ - [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp). - [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp). - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types. - - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). + - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/cpp/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler. * Full support for Blackwell SM100 kernels in CUTLASS 3.x API: - [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that @@ -78,11 +107,11 @@ - A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes. - A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu). * Documentation updates: - - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel). - - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md) - - A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. + - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/cpp/quickstart.md#instantiating-a-blackwell-gemm-kernel). + - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/cpp/blackwell_functionality.md) + - A new [functionality documentation](./media/docs/cpp/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. 
- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture). - - Updates to [profiler documentation](./media/docs/profiler.md) for testing mixed input GEMM kernels on Hopper. + - Updates to [profiler documentation](./media/docs/cpp/profiler.md) for testing mixed input GEMM kernels on Hopper. ## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11) - [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439). @@ -95,7 +124,7 @@ + Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication. + Remove `cute::copy_vec` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment,...)`. + A refactor of default epilogue struct `DefaultEpilogue` [API](./include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel. -- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/profiler.md#cutlass-profiler). +- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/cpp/profiler.md#cutlass-profiler). - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! - Optimal code generation with CUDA toolkit versions 12.6. @@ -109,12 +138,12 @@ - A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. - [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. - [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). -- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md). -- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. 
+- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/cpp/dependent_kernel_launch.md). +- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/cpp/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. - A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support. - A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp). - A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations. -- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper). +- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/cpp/profiler.md#instantiating-more-kernels-with-hopper). - A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h) - Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu). - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! @@ -124,7 +153,7 @@ - [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu) - [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48) -- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and +- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/cpp/profiler.md#GEMM), and [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). - [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu). - A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence: @@ -137,7 +166,7 @@ - Support for residual add (beta != 0) in convolution kernels. - A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output. - A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt). -- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md). 
+- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/cpp/ide_setup.md) and [expanded code style guide](./media/docs/cpp/programming_guidelines.md). - Better support for MSVC as a host compiler. - Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. - Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. @@ -145,7 +174,7 @@ ## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09) - Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp) - + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md). + + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/cpp/gemm_api_3x.md). + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp). + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms + [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. @@ -157,7 +186,7 @@ - 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. + [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934). + [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564). -- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). +- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cpp/cute/03_tensor.md), [MMA atoms](./media/docs/cpp/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). - Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337). - Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17. - Fixes to greatly reduce build warnings. @@ -176,7 +205,7 @@ * Beta release of [Group-GEMM](./examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above). * [Ampere Sparse GEMM](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now. * NamedBarriers usability improvement and list of [ReservedNamedBarriers](./include/cutlass/arch/barrier.h) has been officially released. 
-* Improved [CuTe documentation](./media/docs/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved. +* Improved [CuTe documentation](./media/docs/cpp/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cpp/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cpp/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved. ## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31) * [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types. @@ -227,7 +256,7 @@ * Epilogue builders. Similar to mainloop builders (see [example 49](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization. * Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler. * Performance optimizations for the [*warp-specialized persistent ping-pong*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel. -* Changes to the [GEMM API 3.x](./media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs. +* Changes to the [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs. * [FMHA Backward Pass](./examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers. * [Streamk GEMM with Broadcast](./examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM. * [Batched B2B GEMM](./examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel. @@ -239,10 +268,10 @@ * Updates and bugfixes from the community (thanks!) ## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23) -* [CuTe](./media/docs/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors. -* [A new conceptual operation hierarchy](./media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/gemm_api_3x.md). -* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cutlass_3x_backwards_compatibility.md). 
-* Updates to [Functionality](./media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3. +* [CuTe](./media/docs/cpp/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors. +* [A new conceptual operation hierarchy](./media/docs/cpp/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/cpp/gemm_api_3x.md). +* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cpp/cutlass_3x_backwards_compatibility.md). +* Updates to [Functionality](./media/docs/cpp/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3. * Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#Target-Architecture). * New warp-specialized GEMM [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. * Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations. 
@@ -420,7 +449,7 @@ * Global memory iterators supporting Fprop, Dgrad, and Wgrad * `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture * `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures - * [Documentation](./media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation + * [Documentation](./media/docs/cpp/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation ## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23) * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) @@ -434,7 +463,7 @@ * NVIDIA Ampere GPU Architecture examples and documentation: * [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and * [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu) - * Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue) + * Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/cpp/gemm_api.md#efficient-epilogue) ## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08) * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) @@ -454,7 +483,7 @@ * Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON` ## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06) - * BLAS-style host-side API added to [CUTLASS Library](./media/docs/quickstart.md#cutlass-library) + * BLAS-style host-side API added to [CUTLASS Library](./media/docs/cpp/quickstart.md#cutlass-library) * API to launch compiled kernel instances for GEMM and planar complex GEMM * Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores * Computes complex matrix products on matrices stored as disjoint real and imaginary parts @@ -468,10 +497,10 @@ * Encapsulated functionality embodying modern C++11 programming techniques * Optimized containers and data types for efficient, generic, portable device code * Updates to: - * [Quick start guide](./media/docs/quickstart.md) + * [Quick start guide](./media/docs/cpp/quickstart.md) * [Documentation](./README.md#documentation) - * [Utilities](./media/docs/utilities.md) - * [CUTLASS Profiler](./media/docs/profiler.md) + * [Utilities](./media/docs/cpp/utilities.md) + * [CUTLASS Profiler](./media/docs/cpp/profiler.md) * Native Turing Tensor Cores * Efficient GEMM kernels targeting Turing Tensor Cores * Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d913fed5e..df0926dd1c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -765,6 +765,7 @@ target_include_directories( CUTLASS SYSTEM INTERFACE $ + $ ) install( diff --git a/PUBLICATIONS.md b/PUBLICATIONS.md index 176b42e498..9c89a40f52 100644 --- a/PUBLICATIONS.md +++ b/PUBLICATIONS.md @@ -6,6 +6,8 @@ - ["ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization"](https://arxiv.org/abs/2502.02631). Zechun Liu, Changsheng Zhao, Hanxian Huang, Sijia Chen, Jing Zhang, Jiawei Zhao, Scott Roy, Lisa Jin, Yunyang Xiong, Yangyang Shi, Lin Xiao, Yuandong Tian, Bilge Soran, Raghuraman Krishnamoorthi, Tijmen Blankevoort, Vikas Chandra. _arXiv_, February 2025. 
+- ["Generalized Neighborhood Attention: Multi-dimensional Sparse Attention at the Speed of Light"](https://arxiv.org/abs/2504.16922). Ali Hassani, Fengzhe Zhou, Aditya Kane, Jiannan Huang, Chieh-Yun Chen, Min Shi, Steven Walton, Markus Hoehnerbach, Vijay Thakkar, Michael Isaev, Qinsheng Zhang, Bing Xu, Haicheng Wu, Wen-mei Hwu, Ming-Yu Liu, Humphrey Shi. _arXiv_, April 2025. + ## 2024 - ["DeepSeek-V3 Technical Report"](https://arxiv.org/abs/2412.19437). DeepSeek-AI. _arXiv_, December 2024. diff --git a/README.md b/README.md index ed8011e153..24366fa195 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 3.9.0 +# CUTLASS 3.9.2 -_CUTLASS 3.9.0 - March 2025_ +_CUTLASS 3.9.2 - May 2025_ **This repository fast-follows NVIDIA CUTLASS repository adding SYCL support for Intel GPUs.** The CUDA support is unmodified from upstream and can be used interchangeably. @@ -39,9 +39,9 @@ the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. -See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly. +See the [Quick Start Guide](./media/docs/cpp/quickstart.md) to get started quickly. -See the [functionality docs](./media/docs/functionality.md) for a more comprehensive +See the [functionality docs](./media/docs/cpp/functionality.md) for a more comprehensive list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU architecture. @@ -57,18 +57,35 @@ architecture. - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu). - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu). - [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu). + - [Grouped GEMM with nvfp4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu). + - [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu). + - [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu). * Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM. 
+* Support for Blackwell SM100 Sparse kernels: + - Collective mainloop that targets: + * [SM100 Sparse GEMM](./include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp) +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM: + - [Sparse GEMM](./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu) + - [Blockscaled Sparse GEMM with NVFP4 input data type](./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu) + - [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu) +* Set of unit tests that demonstrate the usage of [sparse](./test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](./test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM. +* A new Multi-head Latent Attention (MLA) kernel for SM100 Blackwell architecture in the CUTLASS [example](./examples/77_blackwell_fmha/) covers the flashMLA-like weight-absorbed decoding use-case. +* A new FMHA Backward kernel for SM100 Blackwell architecture extends the CUTLASS [example](./examples/77_blackwell_fmha/) to show how the five backward pass MMAs can be fused into a single kernel to achieve high performance. +* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture. * Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures: - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. - - Support for [grouped GEMM with blockwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [blockwise and groupwise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in CUTLASS profiler. - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. -* Added support for enhanced kernel performance search in CUTLASS: + - Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture. +* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler: - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels. - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
- Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration. - - More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). + - More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). +* Support `void` as the D element in sm100 kernel epilogues. Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. @@ -115,7 +132,7 @@ Layouts can also be combined and manipulated via functional composition, on whic CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design and improves code composability and readability. More documentation specific to CuTe can be found in its -[dedicated documentation directory](./media/docs/cute/00_quickstart.md). +[dedicated documentation directory](./media/docs/cpp/cute/00_quickstart.md). # Compatibility @@ -162,6 +179,7 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be |NVIDIA H100 Tensor Core GPU |9.0|11.8| |NVIDIA H200 Tensor Core GPU |9.0|11.8| |NVIDIA B200 Tensor Core GPU |10.0|12.8| +|NVIDIA GeForce RTX 50x0 series |10.0|12.8| ## Target Architecture @@ -197,7 +215,7 @@ NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels compiled for Blackwell SM100 architecture with arch conditional features (using `sm100a`) are not compatible with RTX 50 series GPUs. -Please refer to the [functionality documentation](./media/docs/functionality.md) +Please refer to the [functionality documentation](./media/docs/cpp/functionality.md) for details on which kernels require which target architectures. # Documentation @@ -205,22 +223,22 @@ for details on which kernels require which target architectures. CUTLASS is described in the following documents and the accompanying [Doxygen documentation](https://nvidia.github.io/cutlass). 
-- [Quick Start Guide](./media/docs/quickstart.md) - basics of building and running CUTLASS -- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS -- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA -- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components -- [GEMM API 3.x](./media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts -- [GEMM API 2.x](./media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts -- [Implicit GEMM Convolution](./media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS -- [Code Organization](./media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project -- [Terminology](./media/docs/terminology.md) - describes terms used in the code -- [Programming Guidelines](./media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++ -- [Fundamental types](./media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays -- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory -- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory -- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application -- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilitate rapid development -- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent +- [Quick Start Guide](./media/docs/cpp/quickstart.md) - basics of building and running CUTLASS +- [Functionality](./media/docs/cpp/functionality.md) - summarizes functionality available in CUTLASS +- [Efficient GEMM in CUDA](./media/docs/cpp/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA +- [CUTLASS 3.x Design](./media/docs/cpp/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components +- [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts +- [GEMM API 2.x](./media/docs/cpp/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts +- [Implicit GEMM Convolution](./media/docs/cpp/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS +- [Code Organization](./media/docs/cpp/code_organization.md) - describes the organization and contents of the CUTLASS project +- [Terminology](./media/docs/cpp/terminology.md) - describes terms used in the code +- [Programming Guidelines](./media/docs/cpp/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++ +- [Fundamental types](./media/docs/cpp/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays +- [Layouts](./media/docs/cpp/layout.md) - describes layouts of matrices and tensors in memory +- [Tile Iterators](./media/docs/cpp/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory +- [CUTLASS Profiler](./media/docs/cpp/profiler.md) - command-line driven profiling application +- [CUTLASS 
Utilities](./media/docs/cpp/utilities.md) - additional templates used to facilitate rapid development +- [Dependent kernel launch](./media/docs/cpp/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent kernels in the same stream, and how it is used in CUTLASS. # Resources @@ -240,7 +258,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th paths. CUTLASS unit tests, examples, and utilities can be build with CMake. -The minimum version of CMake is given in the [Quickstart guide](./media/docs/quickstart.md). +The minimum version of CMake is given in the [Quickstart guide](./media/docs/cpp/quickstart.md). Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed on your system. @@ -285,7 +303,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl and template concepts defined in the CUTLASS project. A detailed explanation of the source code organization may be found in the -[CUTLASS documentation](./media/docs/code_organization.md), but several main components are summarized below. +[CUTLASS documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below. ## CUTLASS Template Library @@ -359,7 +377,7 @@ tools/ The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate basic usage of Core API components and complete tests of the CUTLASS GEMM computations. -Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/quickstart.md). +Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/cpp/quickstart.md). # Performance Profiling @@ -575,9 +593,9 @@ reference_device: Passed ## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler - Please follow the links for more CMake examples on selectively compiling CUTLASS kernels: - - [GEMM CMake Examples](./media/docs/quickstart.md#gemm-cmake-examples) - - [Implicit GEMM convolution CMake Examples](./media/docs/quickstart.md#convolution-cmake-examples) -- [Further details about the CUTLASS Profiler are described here.](./media/docs/profiler.md) + - [GEMM CMake Examples](./media/docs/cpp/quickstart.md#gemm-cmake-examples) + - [Implicit GEMM convolution CMake Examples](./media/docs/cpp/quickstart.md#convolution-cmake-examples) +- [Further details about the CUTLASS Profiler are described here.](./media/docs/cpp/profiler.md) # About diff --git a/examples/04_tile_iterator/tile_iterator.cu b/examples/04_tile_iterator/tile_iterator.cu index fdfaaac9b2..025eb65f86 100644 --- a/examples/04_tile_iterator/tile_iterator.cu +++ b/examples/04_tile_iterator/tile_iterator.cu @@ -34,7 +34,7 @@ addressable memory, and then store it back into addressable memory. TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data to - and from addressable memory. The PredicateTileIterator accepts a ThreadMap type, which defines + and from addressable memory. The PredicatedTileIterator accepts a ThreadMap type, which defines the mapping of threads to a "tile" in memory. This separation of concerns enables user-defined thread mappings to be specified. @@ -124,7 +124,7 @@ __global__ void copy( cudaError_t TestTileIterator(int M, int K) { - // For this example, we chose a <64, 4> tile shape. The PredicateTileIterator expects + // For this example, we chose a <64, 4> tile shape. 
The PredicatedTileIterator expects // PitchLinearShape and PitchLinear layout. using Shape = cutlass::layout::PitchLinearShape<64, 4>; using Layout = cutlass::layout::PitchLinear; @@ -136,7 +136,7 @@ cudaError_t TestTileIterator(int M, int K) { // dimension then along the strided dimension. using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap; - // Define the PredicateTileIterator, using TileShape, Element, Layout, and ThreadMap types + // Define the PredicatedTileIterator, using TileShape, Element, Layout, and ThreadMap types using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< Shape, Element, Layout, 1, ThreadMap>; diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu index 6fdcc8363f..c9fbd75643 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu @@ -402,7 +402,7 @@ struct Options : MixedDtypeOptions{ void initialize(Options const& options) { auto shape_B = cute::make_shape(options.n, options.k, options.l); - int const scale_k = (options.k + options.g - 1) / options.g; + int const scale_k = cutlass::ceil_div(options.k, options.g); stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); // Reverse stride here due to swap and transpose @@ -429,7 +429,7 @@ void initialize(Options const& options) { block_zero.reset(scale_k * options.l * options.n); initialize_tensor(block_A, seed + 2022); - initialize_quant_tensor(block_B, seed + 2021); + initialize_tensor(block_B, seed + 2021); initialize_tensor(block_C, seed + 2020); initialize_scale(block_scale, options); initialize_zero(block_zero, options); diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu index cc54080393..dcab4a7a49 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu @@ -318,7 +318,7 @@ struct Options : MixedDtypeOptions { void initialize(Options const& options) { auto shape_B = cute::make_shape(options.n, options.k, options.l); - int const scale_k = (options.k + options.g - 1) / options.g; + int const scale_k = cutlass::ceil_div(options.k, options.g); stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); // Reverse stride here due to swap and transpose @@ -347,7 +347,7 @@ void initialize(Options const& options) { block_zero.reset(scale_k * options.l * options.n); initialize_tensor(block_A, seed + 2022); - initialize_quant_tensor(block_B, seed + 2021); + initialize_tensor(block_B, seed + 2021); cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size()); initialize_tensor(block_C, seed + 2020); initialize_scale(block_scale, options); diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu index aa114e74d7..15eb469263 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -288,7 +288,7 @@ cutlass::DeviceAllocation -bool initialize_quant_tensor( - cutlass::DeviceAllocation& block, - uint64_t seed = 
2023) { - - float scope_min = float(cutlass::platform::numeric_limits::lowest()); - float scope_max = float(cutlass::platform::numeric_limits::max()); - - cutlass::reference::device::BlockFillRandomUniform( - block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); - - return true; -} - template bool initialize_scale( cutlass::DeviceAllocation& block, @@ -232,10 +218,8 @@ bool initialize_scale( float scope_max = 1.0f, scope_min = 1.0f; if (options.mode != MixedDtypeGemmMode::ConvertOnly) { float elt_max_f = float(cutlass::platform::numeric_limits::max()); - const float max_dequant_val = 4.f; - const float min_dequant_val = 0.5f; - scope_max = max_dequant_val / elt_max_f; - scope_min = min_dequant_val / elt_max_f; + scope_max = 2.f; + scope_min = 0.1f; } cutlass::reference::device::BlockFillRandomUniform( block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); diff --git a/examples/65_distributed_gemm/65_distributed_gemm.cu b/examples/65_distributed_gemm/65_distributed_gemm.cu index 2289d62a8a..6509609f9f 100644 --- a/examples/65_distributed_gemm/65_distributed_gemm.cu +++ b/examples/65_distributed_gemm/65_distributed_gemm.cu @@ -120,8 +120,7 @@ #include "helper.h" // Distributed GEMM helpers -#include "util/benchmark.h" -#include "util/device_copy.h" +#include "dist_gemm_helpers.h" using namespace cute; @@ -834,10 +833,10 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { + if (props.major != 9 || props.minor != 0) { std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater)." << std::endl; + << "This example requires a GPU of NVIDIA's Hopper Architecture " + << "(compute capability 90)." << std::endl; return 0; } diff --git a/examples/65_distributed_gemm/README.md b/examples/65_distributed_gemm/README.md index e3c48a9dd5..6bfff53c2f 100644 --- a/examples/65_distributed_gemm/README.md +++ b/examples/65_distributed_gemm/README.md @@ -63,6 +63,10 @@ procedure is the same, simply modify the following line in the example: using TP = _8; ``` +## References +* [Distributed GEMM Blog](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b) +* [Distributed GEMM Talk on CUDA Mode](https://www.youtube.com/watch?v=NHRTCQBZokg) + ## Copyright Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/examples/65_distributed_gemm/REQUIREMENTS.md b/examples/65_distributed_gemm/REQUIREMENTS.md index 4b8cca3b4d..c6288a91af 100644 --- a/examples/65_distributed_gemm/REQUIREMENTS.md +++ b/examples/65_distributed_gemm/REQUIREMENTS.md @@ -17,6 +17,8 @@ Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit ar This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary CUDA graph APIs. +The minimum CUDA driver version for running this example is [560.28.03](https://docs.nvidia.com/cuda/archive/12.6.0/cuda-toolkit-release-notes/index.html#id5). + ### Hardware / driver settings This example requires Hopper GPUs with NVLink network. 
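Editorial aside, not part of the patch: the `55_hopper_mixed_dtype_gemm` hunks above replace the hand-written round-up `(options.k + options.g - 1) / options.g` with `cutlass::ceil_div(options.k, options.g)` when sizing `scale_k`. The minimal sketch below spells out the arithmetic both forms compute; `ceil_div_sketch` is a hypothetical stand-in name, not a CUTLASS symbol.

```cpp
// Editorial sketch only -- not part of the patch above.
// "ceil_div_sketch" mirrors the value the old hand-written expression produced.
#include <cassert>

constexpr int ceil_div_sketch(int a, int b) {
  return (a + b - 1) / b;  // rounds the integer quotient up for positive a, b
}

int main() {
  // e.g. k = 1000 quantized elements with group size g = 128 -> 8 scale groups
  assert(ceil_div_sketch(1000, 128) == 8);
  assert(ceil_div_sketch(1024, 128) == 8);  // exact multiples are not rounded up further
  return 0;
}
```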
diff --git a/examples/65_distributed_gemm/util/device_copy.h b/examples/65_distributed_gemm/util/device_copy.h deleted file mode 100644 index 257800a097..0000000000 --- a/examples/65_distributed_gemm/util/device_copy.h +++ /dev/null @@ -1,84 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -/*! \file - \brief generic device-to-device data movement kernel based for CuTe tensors. - - NOTE: this kernel assigns one element copy to every thread, and is by no means - an efficient way of copying tensors. It should only be used for convenience in - reference checks. 
- -*/ - -#pragma once - -#include "cute/layout.hpp" -#include "cute/tensor.hpp" -#include "cutlass/cutlass.h" -#include "cutlass/cuda_host_adapter.hpp" - -namespace cutlass { - -template -void device_copy(TensorSource tensor_source, - TensorDestination tensor_destination, - cudaStream_t stream); - - -template -__global__ void device_copy_kernel(TensorSource const tensor_source, - TensorDestination tensor_destination) { - auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x; - using ElementSrc = typename TensorSource::value_type; - using ElementDst = typename TensorDestination::value_type; - NumericConverter converter; - if (linear_idx < size(tensor_source)) { - tensor_destination(linear_idx) = converter(tensor_source(linear_idx)); - } -} - -template -void device_copy(TensorSource tensor_source, - TensorDestination tensor_destination, - cudaStream_t stream) { - - assert(tensor_source.size() == tensor_destination.size()); - - auto numel = tensor_source.size(); - static constexpr int NumThreads = 128; - auto grid_size = cute::ceil_div(numel, NumThreads); - - dim3 grid(grid_size); - dim3 block(NumThreads); - device_copy_kernel<<>>(tensor_source, tensor_destination); -} - -} //namespace cutlass diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu index 1c21678f10..5d4fe1a180 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu @@ -75,11 +75,11 @@ #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" // Includes from examples directory #include "helper.h" #include "hopper_fp8_commandline.hpp" -#include "reference/host/gemm_with_blockwise_scaling.h" using namespace cute; @@ -123,7 +123,13 @@ using ArchTag = cutlass::arch::Sm90; // T using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster -using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>; + +using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(TileShape{})); + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; @@ -143,8 +149,8 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, ElementAccumulator, TileShape, ClusterShape, 
cutlass::gemm::collective::StageCountAutoCarveout< @@ -190,20 +196,22 @@ StrideB stride_B; StrideC stride_C; StrideD stride_D; StrideAux stride_aux; +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; uint64_t seed; +using LayoutScalar = cutlass::layout::PackedVectorLayout; cutlass::HostTensor tensor_A; cutlass::HostTensor tensor_B; cutlass::HostTensor tensor_C; cutlass::HostTensor tensor_D; uint32_t mma_promotion_interval; -cutlass::HostTensor blockscale_tensor_A; -cutlass::HostTensor blockscale_tensor_B; +cutlass::HostTensor blockscale_tensor_A; +cutlass::HostTensor blockscale_tensor_B; cutlass::HostTensor tensor_ref_D; cutlass::HostTensor tensor_aux; cutlass::HostTensor tensor_ref_aux; -using LayoutScalar = cutlass::layout::PackedVectorLayout; cutlass::HostTensor scalar_alpha; cutlass::HostTensor scalar_beta; cutlass::HostTensor scale_A; @@ -342,26 +350,25 @@ bool initialize_scale_tensor( /// Initialize operands to be used in the GEMM and reference GEMM void initialize(const Options &options) { - // Find Block Scaling tensor shapes based on problem shape and TileShape - auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); - auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{}))); - auto blockscale_m = cute::get<0>(blockscale_shape); - auto blockscale_n = cute::get<1>(blockscale_shape); - auto blockscale_k = cute::get<2>(blockscale_shape); - stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); stride_aux = stride_D; + // Layout SFA and SFB represent logically broadcasting data in CuTe. + // E.g., if Layout SFA has shape ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGranularityK, K / ScaleGranularityK)) + // and strides ((0, 1), (0, M / ScaleGranularityM)), then each collection of ScaleGranularityM x ScaleGranularityK + // indices in the tensor map to the same offset.
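The zero-stride broadcast described in that comment is easiest to see with a few lines of standalone CuTe. The sketch below is illustrative only: the granularity and problem sizes are made up, and it exercises just the layout algebra, not the GEMM itself.

```cpp
// Minimal sketch of a stride-0 "broadcast" scale-factor layout in CuTe.
// Hypothetical sizes for illustration; not the values used by the example.
#include <cstdio>
#include "cute/tensor.hpp"

int main() {
  using namespace cute;
  constexpr int ScaleGranularityM = 4;  // rows of A sharing one scale factor
  constexpr int M = 8;                  // problem size along M
  // Shape (ScaleGranularityM, M / ScaleGranularityM) with strides (0, 1):
  // the stride-0 mode aliases every row inside a granule onto the same offset.
  auto sf_layout = make_layout(make_shape (Int<ScaleGranularityM>{}, Int<M / ScaleGranularityM>{}),
                               make_stride(Int<0>{},                 Int<1>{}));
  std::printf("%d %d %d\n",
              int(sf_layout(0, 0)),    // 0 : rows 0..3 read stored scale 0
              int(sf_layout(3, 0)),    // 0
              int(sf_layout(1, 1)));   // 1 : rows 4..7 read stored scale 1
  // Number of scale factors actually stored; the example sizes its host tensors
  // the same way via size(filter_zeros(layout_SFA)).
  std::printf("%d\n", int(size(filter_zeros(sf_layout))));  // 2
  return 0;
}
```

The `size(filter_zeros(...))` idiom in the last line is the same one the example uses below to size the host-side scale-factor tensors.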
+ layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l)); + layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l)); auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); - auto blockscale_a_coord = cutlass::make_Coord(blockscale_m * options.l, blockscale_k); - auto blockscale_b_coord = cutlass::make_Coord(blockscale_k, blockscale_n * options.l); + auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA))); + auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB))); tensor_A.resize(a_coord); blockscale_tensor_A.resize(blockscale_a_coord); @@ -465,7 +472,9 @@ typename Gemm::Arguments args_from_options(const Options &op stride_B, mma_promotion_interval, blockscale_tensor_A.device_data(), - blockscale_tensor_B.device_data() + layout_SFA, + blockscale_tensor_B.device_data(), + layout_SFB }, { {}, // epilogue.thread @@ -519,12 +528,6 @@ bool verify(const Options &options) { // Compute reference output // - // Block scaling tensors shapes based CTA Block (TileShape) and GEMM Problem shape - auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); - auto blockscale_m = ceil_div(options.m, get<0>(TileShape{})); - auto blockscale_n = ceil_div(options.n, get<1>(TileShape{})); - auto blockscale_k = ceil_div(options.k, get<2>(TileShape{})); - // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(tensor_A.host_data(), cute::make_layout( @@ -557,28 +560,18 @@ bool verify(const Options &options) { ) ); - auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(), - cute::make_layout( - cute::make_shape(blockscale_m, blockscale_k, options.l), - cute::make_stride(1, blockscale_m, blockscale_m * blockscale_k) - ) - ); - auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(), - cute::make_layout( - cute::make_shape(blockscale_n, blockscale_k, options.l), - cute::make_stride(1, blockscale_n, blockscale_n * blockscale_k) - ) - ); + auto SFA = cute::make_tensor(blockscale_tensor_A.host_data(), layout_SFA); + auto SFB = cute::make_tensor(blockscale_tensor_B.host_data(), layout_SFB); using unused_t = decltype(D); - cutlass::reference::host::GettMainloopParams mainloop_params{ - A, B, // Operand Tensors - blockscale_A, blockscale_B // Blockwise scaling Tensors - }; + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; cutlass::reference::host::GettEpilogueParams< ElementScalar, diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu index b7cdb00a67..096e56a6b8 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu @@ -75,11 +75,11 @@ #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/host/tensor_compare.h" #include 
"cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" // Includes from examples directory #include "helper.h" #include "hopper_fp8_commandline.hpp" -#include "reference/host/gemm_with_groupwise_scaling.h" using namespace cute; @@ -120,55 +120,30 @@ using ElementAccumulator = float; // E using ElementBlockScale = float; // Element type for blockscaling during accumulation using ElementCompute = float; // Element type for epilogue computation -using TileShape_ = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()... - -// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor -// Given TileShape = Shape<_128,_128,_128>: -// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor) -// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling -// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling -// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling -template -struct GroupScaleConfig { - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag - using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size - using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster - - static constexpr int ScaleGranularityM = ScaleGranularityM_; - static constexpr int ScaleGranularityN = ScaleGranularityN_; - static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; - static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; - - static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile, - "FP8 scaling granularity must evenly divide tile shape along M."); - static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile, - "FP8 scaling granularity must evenly divide tile shape along N."); - - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; - using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; - using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster + +constexpr int ScaleGranularityM = 1; +constexpr int ScaleGranularityN = 128; +constexpr int ScaleGranularityK = 128; + +constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; +constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + +using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +using FusionOperation = 
cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>; -}; -using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>; -using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>; -using GroupScale2D1DConfig = GroupScaleConfig(TileShape_{}), 1>; -using GroupScale2D2DConfig = GroupScaleConfig(TileShape_{}), size<1>(TileShape_{})>; - -template -struct GroupScaleGemm { - using ArchTag = typename ScheduleConfig::ArchTag; - using OperatorClass = typename ScheduleConfig::OperatorClass; - using TileShape = typename ScheduleConfig::TileShape; - using ClusterShape = typename ScheduleConfig::ClusterShape; - using KernelSchedule = typename ScheduleConfig::KernelSchedule; - using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; - using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; - using FusionOperation = typename ScheduleConfig::FusionOperation; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, @@ -179,10 +154,10 @@ struct GroupScaleGemm { FusionOperation >::CollectiveOp; - using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< @@ -191,38 +166,26 @@ struct GroupScaleGemm { KernelSchedule >::CollectiveOp; - using GemmKernelDefault = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloopWithGroupWiseScaling, - CollectiveEpilogue - >; - using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloopWithGroupWiseScaling, - CollectiveEpilogue, - cutlass::gemm::StreamKScheduler +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler >; - using GemmDefault = cutlass::gemm::device::GemmUniversalAdapter; - using GemmStreamK = cutlass::gemm::device::GemmUniversalAdapter; -}; - -using GroupScale1D1DGemm = GroupScaleGemm; -using GroupScale1D2DGemm = GroupScaleGemm; -using GroupScale2D1DGemm = GroupScaleGemm; -using GroupScale2D2DGemm = GroupScaleGemm; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; // Extract information from Gemm kernel. 
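With the per-configuration boilerplate gone, trying a different scaling granularity comes down to editing the three constants above and letting the config re-deduce the scale-factor layouts. The fragment below is a hypothetical variant, not part of the example: the 128/1/128 values are arbitrary, and it assumes the same headers and surrounding type definitions as the example file.

```cpp
// Hypothetical alternative granularity (values are illustrative only).
constexpr int ScaleGranularityM = 128;  // one scale factor per 128 rows of A
constexpr int ScaleGranularityN = 1;    // one scale factor per column of B
constexpr int ScaleGranularityK = 128;  // one scale factor per 128-wide K block

using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
    ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;

using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

// At runtime the layouts are rebuilt from the problem shape exactly as in initialize():
//   layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L));
//   layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L));
```

The granularities must still divide the 128x128x128 tile shape used by the example, which these illustrative values do.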
-using EpilogueOutputOp = typename GroupScale1D1DGemm::GemmDefault::EpilogueOutputOp; +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; using ElementScalar = typename EpilogueOutputOp::ElementScalar; using ElementAmax = typename EpilogueOutputOp::ElementAmax; using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; -using StrideA = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideA; -using StrideB = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideB; -using StrideC = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideC; -using StrideD = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideD; +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; using StrideAux = StrideD; constexpr bool IsDFp8 = @@ -242,20 +205,23 @@ StrideB stride_B; StrideC stride_C; StrideD stride_D; StrideAux stride_aux; +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; uint64_t seed; +using LayoutScalar = cutlass::layout::PackedVectorLayout; + cutlass::HostTensor tensor_A; cutlass::HostTensor tensor_B; cutlass::HostTensor tensor_C; cutlass::HostTensor tensor_D; uint32_t mma_promotion_interval; -cutlass::HostTensor blockscale_tensor_A; -cutlass::HostTensor blockscale_tensor_B; +cutlass::HostTensor blockscale_tensor_A; +cutlass::HostTensor blockscale_tensor_B; cutlass::HostTensor tensor_ref_D; cutlass::HostTensor tensor_aux; cutlass::HostTensor tensor_ref_aux; -using LayoutScalar = cutlass::layout::PackedVectorLayout; cutlass::HostTensor scalar_alpha; cutlass::HostTensor scalar_beta; cutlass::HostTensor scale_A; @@ -392,32 +358,25 @@ bool initialize_scale_tensor( } /// Initialize operands to be used in the GEMM and reference GEMM -template void initialize(const Options &options) { - using TileShape = typename GroupScaleConfig::TileShape; - const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM; - const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN; - assert(options.m % ScaleGranularityM == 0); assert(options.n % ScaleGranularityN == 0); - // Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape - auto groupscale_m = ceil_div(options.m, ScaleGranularityM); - auto groupscale_n = ceil_div(options.n, ScaleGranularityN); - auto blockscale_k = ceil_div(options.k, cute::get<2>(TileShape{})); - stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); stride_aux = stride_D; + layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l)); + layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l)); + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); - auto groupscale_a_coord = cutlass::make_Coord(groupscale_m * options.l, blockscale_k); - auto groupscale_b_coord = cutlass::make_Coord(groupscale_n * options.l, blockscale_k); + auto groupscale_a_coord = 
cutlass::make_Coord(size(filter_zeros(layout_SFA))); + auto groupscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB))); tensor_A.resize(a_coord); tensor_B.resize(b_coord); @@ -522,7 +481,9 @@ GemmArguments args_from_options(const Options &options) stride_B, mma_promotion_interval, blockscale_tensor_A.device_data(), - blockscale_tensor_B.device_data() + layout_SFA, + blockscale_tensor_B.device_data(), + layout_SFB }, { {}, // epilogue.thread @@ -572,19 +533,10 @@ GemmArguments args_from_options(const Options &options) } /// Don't know why the compiler does not like verify() being templated... -bool verify(const Options &options, const int ScaleMsPerTile, const int ScaleNsPerTile) { +bool verify(const Options &options) { // // Compute reference output // - const int ScaleGranularityM = get<0>(TileShape_{}) / ScaleMsPerTile; - const int ScaleGranularityN = get<1>(TileShape_{}) / ScaleNsPerTile; - - // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape - auto blockscale_m = ceil_div(options.m, get<0>(TileShape_{})); - auto blockscale_n = ceil_div(options.n, get<1>(TileShape_{})); - auto blockscale_k = ceil_div(options.k, get<2>(TileShape_{})); - auto groupscale_m = ceil_div(options.m, ScaleGranularityM); - auto groupscale_n = ceil_div(options.n, ScaleGranularityN); // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(tensor_A.host_data(), @@ -618,28 +570,18 @@ bool verify(const Options &options, const int ScaleMsPerTile ) ); - auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(), - cute::make_layout( - cute::make_shape(groupscale_m, blockscale_k, options.l), - cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k) - ) - ); - auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(), - cute::make_layout( - cute::make_shape(groupscale_n, blockscale_k, options.l), - cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k) - ) - ); + auto SFA = cute::make_tensor(blockscale_tensor_A.host_data(), layout_SFA); + auto SFB = cute::make_tensor(blockscale_tensor_B.host_data(), layout_SFB); using unused_t = decltype(D); - cutlass::reference::host::GettMainloopParams mainloop_params{ - A, B, // Operand Tensors - blockscale_A, blockscale_B // Groupwise scaling Tensors - }; + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; cutlass::reference::host::GettEpilogueParams< ElementScalar, @@ -713,14 +655,7 @@ bool verify(const Options &options, const int ScaleMsPerTile } /// Execute a given example GEMM computation -template -int run(Options &options) -{ - using TileShape = typename GroupScaleConfig::TileShape; - const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM; - const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN; - const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile; - const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile; +int run(Options &options) { bool skip = false; std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; @@ -747,7 +682,7 @@ int run(Options &options) if (!skip) std::cout << " Running... 
" << std::endl; else return -1; - initialize(options); + initialize(options); // Instantiate CUTLASS kernel depending on templates Gemm gemm; @@ -773,7 +708,7 @@ int run(Options &options) // Check if output from CUTLASS kernel and reference kernel are equal or not Result result; if (options.verify) { - result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile); + result.passed = verify(options); std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; } @@ -860,28 +795,7 @@ int main(int argc, char const **args) { #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) bool passed = true; - std::cout << "Basic split-K GEMM kernel" << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - - std::cout << std::endl; - - std::cout << "StreamK GEMM kernel" << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - + passed = run(options); if (!passed) return -1; #endif diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_blockwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_blockwise_scaling.h deleted file mode 100644 index 8904060cba..0000000000 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_blockwise_scaling.h +++ /dev/null @@ -1,504 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! 
\file - \brief Reference implementation for GETT in host-side code. -*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/gemm/gemm.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/relatively_equal.h" -#include -#include "cute/tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::reference::host { - -template -struct ElementTraits { - using type = T; -}; - -template -struct ElementTraits().get()), void> > > { - using type = decltype(std::declval().get()); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - class ElementAccumulator_, - class TensorA_, // (M, K, L) - class TensorB_, // (N, K, L) - class TensorScaleA_, // (m, k, L) - class TensorScaleB_, // (n, k, L) - class TileShape_ -> -struct GettMainloopParams { - using ElementAccumulator = ElementAccumulator_; - using TensorA = TensorA_; - using TensorB = TensorB_; - using EngineA = typename TensorA::engine_type; - using LayoutA = typename TensorA::layout_type; - using EngineB = typename TensorB::engine_type; - using LayoutB = typename TensorB::layout_type; - - using TensorScaleA = TensorScaleA_; - using TensorScaleB = TensorScaleB_; - using TileShape = TileShape_; - using EngineScaleA = typename TensorScaleA::engine_type; - using EngineScaleB = typename TensorScaleB::engine_type; - - TensorA A{}; - TensorB B{}; - TensorScaleA ScaleA{}; - TensorScaleB ScaleB{}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template< - class ElementScalar_, - class ElementScalingFactor_, - class ElementAccumulator_, - class ElementCompute_, - class TensorC_, // (M, N, L) - class TensorD_, // (M, N, L) - class VectorBias_ = TensorD_, // (M, 1) - class TensorAux_ = TensorD_, // (M, N, L) - class VectorAlpha_ = TensorD_, // (M, 1) - class VectorBeta_ = VectorAlpha_, // (M, 1) - class ActivationFunctor_ = cutlass::epilogue::thread::Identity, - class BiasBinaryOp_ = cutlass::plus, - bool PerColumnBias_ = false -> -struct GettEpilogueParams { - using ElementScalar = ElementScalar_; - using ElementScalingFactor = ElementScalingFactor_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using TensorC = TensorC_; - using TensorD = TensorD_; - using TensorAux = TensorAux_; - using VectorBias = VectorBias_; - using VectorAlpha = VectorAlpha_; - using VectorBeta = VectorBeta_; - using ActivationFunctor = ActivationFunctor_; - using BiasBinaryOp = BiasBinaryOp_; - - using EngineC = typename TensorC::engine_type; - using LayoutC = typename TensorC::layout_type; - using EngineD = typename TensorD::engine_type; - using LayoutD = typename TensorD::layout_type; - static constexpr bool PerColumnBias = PerColumnBias_; - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - - TensorC C{}; - TensorD D{}; - VectorBias Bias{}; - TensorAux Aux{}; - VectorAlpha Valpha{}; - VectorBeta Vbeta{}; - ElementCompute st = ElementCompute(1); - - ElementAccumulator* abs_max_D = nullptr; - ElementAccumulator* abs_max_Aux = nullptr; - - ElementScalingFactor scale_a = ElementScalingFactor(1); - ElementScalingFactor scale_b = ElementScalingFactor(1); - ElementScalingFactor scale_c = ElementScalingFactor(1); - ElementScalingFactor scale_d = 
ElementScalingFactor(1); - ElementScalingFactor scale_aux = ElementScalingFactor(1); - - bool beta_per_channel_scaling = false; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - General Tensor-Tensor contraction reference kernel with Blockwise scaling -template < - class MainloopParams, - class EpilogueParams -> -void Gett( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - - static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{}); - static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{}); - // printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n"); - // printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n"); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { - for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { - for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { - typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; - gett_mainloop(mainloop_params, m, n, l, acc); - gett_epilogue(epilogue_params, m, n, l, acc); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Mainloop -template -void gett_mainloop( - MainloopParams const& mainloop_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); - static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementA = typename ElementTraits::type; - using ElementB = typename ElementTraits::type; - using ElementBlockScaleA = typename ElementTraits::type; - using ElementBlockScaleB = typename ElementTraits::type; - - using RingOp = multiply_add; - RingOp fma_op; - - multiplies scale_op; - - static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});; - - // Tempo accumulators to seperate blockwise accumulation - typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN]; - - // Zero out accumulators - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - - int64_t block_m = m / kBlockM; - int64_t block_n = n / kBlockN; - cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, l); - cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, l); - - // Compute on this k-block - for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { - - // Load Blockwise scaling factor from blockscale Tensors for A and B - int64_t block_k = k / kBlockK; - ElementBlockScaleA scale_a = blockscale_A[block_k]; - ElementBlockScaleB scale_b = blockscale_B[block_k]; - - // Load A - ElementAccumulator a_frag[kBlockM]; - for (int m_b = 0; m_b < kBlockM; ++m_b) { - if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. 
- a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); - } else { - a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // Load B - ElementAccumulator b_frag[kBlockN]; - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); - } else { - b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // do compute - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]); - } - } - - // Apply Blockwise-scaling at kBlockK boundary - // (a) Apply block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary - // (b) Zero-out partial temporary (acc_temp), - // (c) Update permanent (accu) - if ((k+1) % kBlockK == 0) { - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a * scale_b; - acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b]; - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - } - - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Epilogue -template -void gett_epilogue( - EpilogueParams const& epilogue_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); - static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementCompute = typename EpilogueParams::ElementCompute; - using ElementC = typename EpilogueParams::TensorC::value_type; - using ElementD = typename EpilogueParams::TensorD::value_type; - using ElementAux = typename EpilogueParams::TensorAux::value_type; - using ElementBias = typename EpilogueParams::VectorBias::value_type; - using ElementScalar = typename EpilogueParams::ElementScalar; - using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; - using ActivationFunctor = typename EpilogueParams::ActivationFunctor; - using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; - - constexpr bool PerColBias = EpilogueParams::PerColumnBias; - constexpr bool IsScalingAndAmaxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsScalingAndAmaxAuxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsReLUAuxNeeded = - (cute::is_same_v> or - cute::is_same_v>) and - cute::is_same_v; - constexpr bool IsClamp = - cute::is_same_v>; - - constexpr bool IsBackpropFusion = - cute::is_same_v> or - cute::is_same_v>; - - // Input related converter - NumericConverter accumulator_converter; - NumericConverter source_converter; - NumericConverter bias_converter; - [[maybe_unused]] NumericConverter aux_source_converter; - - // Scale related converter - NumericConverter scale_converter; - NumericConverter scaling_factor_converter; - - // Abs max converter - [[maybe_unused]] NumericConverter abs_max_output_converter; - - // Output related converter - NumericConverter destination_converter; - [[maybe_unused]] NumericConverter aux_destination_converter; - NumericConverter dBias_converter; - - // Epilogue operations - multiply_add epilogue_fma; - multiplies 
mul; - plus add; - - // Activation operation - ActivationFunctor activation; - - // Bias binary operation - BiasBinaryOp bias_op; - - // Do conversion - ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); - ElementCompute converted_beta = scale_converter(epilogue_params.beta); - ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); - ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); - ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); - ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); - ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); - - // Init local var - [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); - [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); - - converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); - converted_beta = mul(converted_beta, converted_scale_c); - - ElementCompute inter_accum[kBlockM][kBlockN]; - - for (int m_b = 0; m_b < kBlockM; ++m_b) { - ElementCompute local_dBias = ElementCompute(0); - - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - // Convert every type to ElementCompute first, do compute, convert to output type, write it out - ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); - // per-row alpha - if (raw_pointer_cast(epilogue_params.Valpha.data())) { - converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b)); - } - ElementCompute output = mul(converted_alpha, converted_acc); - - if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { - ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b)); - output = bias_op(output, converted_bias); - } - - if (raw_pointer_cast(epilogue_params.C.data())) { - ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); - // per-row beta - if (epilogue_params.Vbeta.data()) { - converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b)); - } - output = epilogue_fma(converted_beta, converted_src, output); - } - - if constexpr (IsBackpropFusion) { - ElementAux aux_input = ElementAux(0); - if (raw_pointer_cast(epilogue_params.Aux.data())) { - aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); - } - - output = activation(output, aux_source_converter(aux_input)); - local_dBias = add(local_dBias, output); - } - else { - if (raw_pointer_cast(epilogue_params.Aux.data())) { - auto aux_output = output; - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); - aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); - } - - if constexpr (IsReLUAuxNeeded) { - epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? 
uint1b_t(1) : uint1b_t(0); - } else { - epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); - } - } - - if constexpr (IsClamp) { // Treat Clamp as ReLU - output = activation(output, {0, std::numeric_limits::max()}); - } - else { - output = activation(output); - } - } - - if constexpr (IsScalingAndAmaxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_output = amax_op(local_abs_max_output, output); - output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); - } - - inter_accum[m_b][n_b] = ElementCompute(output); - } - } // n_b - - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { - if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { - ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); - local_dBias = add(local_dBias, converted_dBias); - epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); - } - } - } // m_b - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); - } - } - } - -#if defined(_OPENMP) - #pragma omp critical(Abs_Max_Data_Update) -#endif - { - if constexpr (IsScalingAndAmaxOutputNeeded) { - if (epilogue_params.abs_max_D) { - *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); - } - } - - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - if (epilogue_params.abs_max_Aux) { - *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GEMM - General Matrix-Matrix contraction without conjugation options -template < - class MainloopParams, - class EpilogueParams -> -void Gemm3x( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - using namespace cute; - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) " - "with Batchmode are supported"); - // Lower the Matrix-Multiplication with Blockwise scaling (Gemm3x) to a Tensor Contraction (Gett). 
- Gett(mainloop_params, epilogue_params); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // cutlass::reference::host - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h deleted file mode 100644 index 0bf90a4163..0000000000 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h +++ /dev/null @@ -1,518 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GETT in host-side code. 
-*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/gemm/gemm.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/relatively_equal.h" -#include -#include "cute/tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::reference::host { - -template -struct ElementTraits { - using type = T; -}; - -template -struct ElementTraits().get()), void> > > { - using type = decltype(std::declval().get()); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - class ElementAccumulator_, - class TensorA_, // (M, K, L) - class TensorB_, // (N, K, L) - class TensorScaleA_, // (m, k, L) - class TensorScaleB_, // (n, k, L) - class TileShape_ -> -struct GettMainloopParams { - using ElementAccumulator = ElementAccumulator_; - using TensorA = TensorA_; - using TensorB = TensorB_; - using EngineA = typename TensorA::engine_type; - using LayoutA = typename TensorA::layout_type; - using EngineB = typename TensorB::engine_type; - using LayoutB = typename TensorB::layout_type; - - using TensorScaleA = TensorScaleA_; - using TensorScaleB = TensorScaleB_; - using TileShape = TileShape_; - using EngineScaleA = typename TensorScaleA::engine_type; - using EngineScaleB = typename TensorScaleB::engine_type; - - TensorA A{}; - TensorB B{}; - TensorScaleA ScaleA{}; - TensorScaleB ScaleB{}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template< - class ElementScalar_, - class ElementScalingFactor_, - class ElementAccumulator_, - class ElementCompute_, - class TensorC_, // (M, N, L) - class TensorD_, // (M, N, L) - class VectorBias_ = TensorD_, // (M, 1) - class TensorAux_ = TensorD_, // (M, N, L) - class VectorAlpha_ = TensorD_, // (M, 1) - class VectorBeta_ = VectorAlpha_, // (M, 1) - class ActivationFunctor_ = cutlass::epilogue::thread::Identity, - class BiasBinaryOp_ = cutlass::plus, - bool PerColumnBias_ = false -> -struct GettEpilogueParams { - using ElementScalar = ElementScalar_; - using ElementScalingFactor = ElementScalingFactor_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using TensorC = TensorC_; - using TensorD = TensorD_; - using TensorAux = TensorAux_; - using VectorBias = VectorBias_; - using VectorAlpha = VectorAlpha_; - using VectorBeta = VectorBeta_; - using ActivationFunctor = ActivationFunctor_; - using BiasBinaryOp = BiasBinaryOp_; - - using EngineC = typename TensorC::engine_type; - using LayoutC = typename TensorC::layout_type; - using EngineD = typename TensorD::engine_type; - using LayoutD = typename TensorD::layout_type; - static constexpr bool PerColumnBias = PerColumnBias_; - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - - TensorC C{}; - TensorD D{}; - VectorBias Bias{}; - TensorAux Aux{}; - VectorAlpha Valpha{}; - VectorBeta Vbeta{}; - ElementCompute st = ElementCompute(1); - - ElementAccumulator* abs_max_D = nullptr; - ElementAccumulator* abs_max_Aux = nullptr; - - ElementScalingFactor scale_a = ElementScalingFactor(1); - ElementScalingFactor scale_b = ElementScalingFactor(1); - ElementScalingFactor scale_c = ElementScalingFactor(1); - ElementScalingFactor scale_d = ElementScalingFactor(1); - ElementScalingFactor scale_aux = 
ElementScalingFactor(1); - - bool beta_per_channel_scaling = false; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling -template < - class MainloopParams, - class EpilogueParams -> -void Gett( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - - static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{}); - static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{}); - // printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n"); - // printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n"); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { - for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { - for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { - typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; - gett_mainloop(mainloop_params, m, n, l, acc); - gett_epilogue(epilogue_params, m, n, l, acc); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Mainloop -template -void gett_mainloop( - MainloopParams const& mainloop_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); - static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementA = typename ElementTraits::type; - using ElementB = typename ElementTraits::type; - using ElementBlockScaleA = typename ElementTraits::type; - using ElementBlockScaleB = typename ElementTraits::type; - - using RingOp = multiply_add; - RingOp fma_op; - - multiplies scale_op; - - static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});; - - // Tempo accumulators to seperate blockwise accumulation - typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN]; - - // Zero out accumulators - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - - const int M = cute::size<0>(mainloop_params.A.layout()); - const int N = cute::size<0>(mainloop_params.B.layout()); - const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA); - const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB); - assert(ScaleGranularityM && M % ScaleGranularityM == 0 - && "ScaleGranularityM must divide M"); - assert(ScaleGranularityN && N % ScaleGranularityN == 0 - && "ScaleGranularityN must divide N"); - - cute::Tensor blockscale_A = domain_offset( - make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l)); - cute::Tensor blockscale_B = domain_offset( - make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l)); - - // Compute on this k-block - for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { - - // Load Blockwise scaling factor from blockscale Tensors for B - int64_t block_k = k / kBlockK; - cute::Tensor scale_a = blockscale_A(_, block_k); - cute::Tensor scale_b = blockscale_B(_, 
block_k); - - // Load A - ElementAccumulator a_frag[kBlockM]; - for (int m_b = 0; m_b < kBlockM; ++m_b) { - if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); - } else { - a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // Load B - ElementAccumulator b_frag[kBlockN]; - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); - } else { - b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - int m_size = std::min(static_cast(kBlockM), cute::size<0>(mainloop_params.A.layout()) - m); - int n_size = std::min(static_cast(kBlockN), cute::size<0>(mainloop_params.B.layout()) - n); - - // do compute - for (int m_b = 0; m_b < m_size; ++m_b) { - for (int n_b = 0; n_b < n_size; ++n_b) { - acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]); - } - } - - // Apply Groupwise-scaling at kBlockK boundary - // (a) Apply group and block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary - // (b) Zero-out partial temporary (acc_temp), - // (c) Update permanent (accu) - if ((k+1) % kBlockK == 0) { - for (int m_b = 0; m_b < m_size; ++m_b) { - auto scale_a_m_b = scale_a[m_b / ScaleGranularityM]; - for (int n_b = 0; n_b < n_size; ++n_b) { - auto scale_b_n_b = scale_b[n_b / ScaleGranularityN]; - ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b; - acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b]; - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - } - - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Epilogue -template -void gett_epilogue( - EpilogueParams const& epilogue_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); - static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementCompute = typename EpilogueParams::ElementCompute; - using ElementC = typename EpilogueParams::TensorC::value_type; - using ElementD = typename EpilogueParams::TensorD::value_type; - using ElementAux = typename EpilogueParams::TensorAux::value_type; - using ElementBias = typename EpilogueParams::VectorBias::value_type; - using ElementScalar = typename EpilogueParams::ElementScalar; - using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; - using ActivationFunctor = typename EpilogueParams::ActivationFunctor; - using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; - - constexpr bool PerColBias = EpilogueParams::PerColumnBias; - constexpr bool IsScalingAndAmaxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsScalingAndAmaxAuxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsReLUAuxNeeded = - (cute::is_same_v> or - cute::is_same_v>) and - cute::is_same_v; - constexpr bool IsClamp = - cute::is_same_v>; - - constexpr bool IsBackpropFusion = - cute::is_same_v> or - cute::is_same_v>; - - // Input related converter - NumericConverter 
accumulator_converter; - NumericConverter source_converter; - NumericConverter bias_converter; - [[maybe_unused]] NumericConverter aux_source_converter; - - // Scale related converter - NumericConverter scale_converter; - NumericConverter scaling_factor_converter; - - // Abs max converter - [[maybe_unused]] NumericConverter abs_max_output_converter; - - // Output related converter - NumericConverter destination_converter; - [[maybe_unused]] NumericConverter aux_destination_converter; - NumericConverter dBias_converter; - - // Epilogue operations - multiply_add epilogue_fma; - multiplies mul; - plus add; - - // Activation operation - ActivationFunctor activation; - - // Bias binary operation - BiasBinaryOp bias_op; - - // Do conversion - ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); - ElementCompute converted_beta = scale_converter(epilogue_params.beta); - ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); - ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); - ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); - ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); - ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); - - // Init local var - [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); - [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); - - converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); - converted_beta = mul(converted_beta, converted_scale_c); - - ElementCompute inter_accum[kBlockM][kBlockN]; - - for (int m_b = 0; m_b < kBlockM; ++m_b) { - ElementCompute local_dBias = ElementCompute(0); - - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - // Convert every type to ElementCompute first, do compute, convert to output type, write it out - ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); - // per-row alpha - if (raw_pointer_cast(epilogue_params.Valpha.data())) { - converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b)); - } - ElementCompute output = mul(converted_alpha, converted_acc); - - if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { - ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? 
n + n_b : m + m_b)); - output = bias_op(output, converted_bias); - } - - if (raw_pointer_cast(epilogue_params.C.data())) { - ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); - // per-row beta - if (epilogue_params.Vbeta.data()) { - converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b)); - } - output = epilogue_fma(converted_beta, converted_src, output); - } - - if constexpr (IsBackpropFusion) { - ElementAux aux_input = ElementAux(0); - if (raw_pointer_cast(epilogue_params.Aux.data())) { - aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); - } - - output = activation(output, aux_source_converter(aux_input)); - local_dBias = add(local_dBias, output); - } - else { - if (raw_pointer_cast(epilogue_params.Aux.data())) { - auto aux_output = output; - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); - aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); - } - - if constexpr (IsReLUAuxNeeded) { - epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0); - } else { - epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); - } - } - - if constexpr (IsClamp) { // Treat Clamp as ReLU - output = activation(output, {0, std::numeric_limits::max()}); - } - else { - output = activation(output); - } - } - - if constexpr (IsScalingAndAmaxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_output = amax_op(local_abs_max_output, output); - output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); - } - - inter_accum[m_b][n_b] = ElementCompute(output); - } - } // n_b - - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { - if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { - ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); - local_dBias = add(local_dBias, converted_dBias); - epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); - } - } - } // m_b - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); - } - } - } - -#if defined(_OPENMP) - #pragma omp critical(Abs_Max_Data_Update) -#endif - { - if constexpr (IsScalingAndAmaxOutputNeeded) { - if (epilogue_params.abs_max_D) { - *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); - } - } - - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - if (epilogue_params.abs_max_Aux) { - *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GEMM - General Matrix-Matrix contraction without conjugation options -template < - class MainloopParams, - class EpilogueParams -> -void Gemm3x( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - using namespace cute; - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); - static_assert(cute::rank(typename 
EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) " - "with Batchmode are supported"); - // Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett). - Gett(mainloop_params, epilogue_params); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // cutlass::reference::host - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu index d20bad5827..d14360deb6 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu @@ -87,11 +87,11 @@ #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" // Includes from examples directory #include "helper.h" #include "hopper_fp8_commandline.hpp" -#include "reference/host/gemm_with_groupwise_scaling.h" using namespace cute; @@ -128,54 +128,29 @@ using ElementAccumulator = float; // E using ElementBlockScale = float; // Element type for blockscaling during accumulation using ElementCompute = float; // Element type for epilogue computation -using TileShape_ = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()... 
- -// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor -// Given TileShape = Shape<_128,_128,_128>: -// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor) -// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling -// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling -// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling -template -struct GroupScaleConfig { - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag - using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size - using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster - - static constexpr int ScaleGranularityM = ScaleGranularityM_; - static constexpr int ScaleGranularityN = ScaleGranularityN_; - static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; - static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; - - static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile, - "FP8 scaling granularity must evenly divide tile shape along M."); - static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile, - "FP8 scaling granularity must evenly divide tile shape along N."); - - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; - using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; - using FusionOperation = cutlass::epilogue::fusion::LinearCombination; -}; +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster + +constexpr int ScaleGranularityM = 1; +constexpr int ScaleGranularityN = 128; +constexpr int ScaleGranularityK = 128; + +constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; +constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + +using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; -using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>; -using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>; -using GroupScale2D1DConfig = GroupScaleConfig(TileShape_{}), 1>; -using GroupScale2D2DConfig = GroupScaleConfig(TileShape_{}), size<1>(TileShape_{})>; - -template -struct GroupScaleGemm { - using ArchTag = typename ScheduleConfig::ArchTag; - using OperatorClass = typename ScheduleConfig::OperatorClass; - using TileShape = typename ScheduleConfig::TileShape; - using ClusterShape = typename ScheduleConfig::ClusterShape; - using KernelSchedule = typename ScheduleConfig::KernelSchedule; - using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; - using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; - using FusionOperation = typename ScheduleConfig::FusionOperation; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix 
operand + +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +using FusionOperation = cutlass::epilogue::fusion::LinearCombination; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, @@ -186,10 +161,10 @@ struct GroupScaleGemm { FusionOperation >::CollectiveOp; - using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< +using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, - ElementA, LayoutA *, AlignmentA, - ElementB, LayoutB *, AlignmentB, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< @@ -198,29 +173,23 @@ struct GroupScaleGemm { KernelSchedule >::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - CollectiveMainloopWithGroupWiseScaling, - CollectiveEpilogue +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopWithGroupWiseScaling, + CollectiveEpilogue >; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -}; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -using GroupScale1D1DGemm = GroupScaleGemm; -using GroupScale1D2DGemm = GroupScaleGemm; -using GroupScale2D1DGemm = GroupScaleGemm; -using GroupScale2D2DGemm = GroupScaleGemm; // Extract information from Gemm kernel. -using EpilogueOutputOp = typename GroupScale1D1DGemm::Gemm::EpilogueOutputOp; +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; using ElementScalar = typename EpilogueOutputOp::ElementScalar; -using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; -using StrideA = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideA; -using StrideB = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideB; -using StrideC = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideC; -using StrideD = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideD; +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; static_assert(cute::is_same_v, "ElementAccumulator and ElementBlockScale should be same datatype"); @@ -240,6 +209,8 @@ std::vector stride_A_host; std::vector stride_B_host; std::vector stride_C_host; std::vector stride_D_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; std::vector alpha_host; std::vector beta_host; @@ -265,6 +236,8 @@ cutlass::DeviceAllocation stride_A; cutlass::DeviceAllocation stride_B; cutlass::DeviceAllocation stride_C; cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; cutlass::DeviceAllocation alpha_device; cutlass::DeviceAllocation beta_device; @@ -343,10 +316,6 @@ bool initialize_block( template void allocate(const OptionType &options) { - using TileShape = typename OptionType::GroupScaleConfig::TileShape; - const int ScaleMsPerTile = OptionType::GroupScaleConfig::ScaleMsPerTile; - const int ScaleNsPerTile = OptionType::GroupScaleConfig::ScaleNsPerTile; 
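For reference, the 1x128x128 granularity chosen above (ScaleGranularityM = 1, ScaleGranularityN = 128, ScaleGranularityK = 128) means A carries one scale factor per row per 128-deep K block, while B carries one scale factor per 128-wide N block per 128-deep K block. The example derives the exact (possibly padded) scale-factor layouts from ScaleConfig::tile_atom_to_shape_SFA/SFB and sizes the allocations with size(filter_zeros(...)); the snippet below is only a rough back-of-the-envelope sketch of those per-group element counts, with hypothetical helper names.

```cpp
// Illustrative only: approximate per-group scale-factor counts for the
// 1x128x128 granularity above. The example itself uses
// ScaleConfig::tile_atom_to_shape_SFA/SFB + size(filter_zeros(...)) to get
// the exact extents; this helper is a hypothetical sketch.
#include <cstdint>
#include <cstdio>

constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

struct ScaleFactorCounts {
  int64_t sfa;  // one scale per (gM rows of A) x (gK columns of K)
  int64_t sfb;  // one scale per (gN columns of B) x (gK rows of K)
};

ScaleFactorCounts approx_scale_counts(int64_t M, int64_t N, int64_t K,
                                      int64_t gM = 1, int64_t gN = 128, int64_t gK = 128) {
  return { ceil_div(M, gM) * ceil_div(K, gK),
           ceil_div(N, gN) * ceil_div(K, gK) };
}

int main() {
  auto c = approx_scale_counts(2816, 3072, 16384);
  std::printf("SFA elements ~%lld, SFB elements ~%lld\n",
              (long long)c.sfa, (long long)c.sfb);
}
```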
- int64_t total_elements_A = 0; int64_t total_elements_B = 0; int64_t total_elements_C = 0; @@ -372,10 +341,8 @@ void allocate(const OptionType &options) { auto N = get<1>(problem); auto K = get<2>(problem); - auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{}))); - auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access. - auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of A to prevent illegal memory access. - auto blockscale_k = cute::get<2>(blockscale_shape); + auto group_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + auto group_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); offset_A.push_back(total_elements_A); offset_B.push_back(total_elements_B); @@ -388,8 +355,8 @@ void allocate(const OptionType &options) { int64_t elements_B = K * N; int64_t elements_C = M * N; int64_t elements_D = M * N; - int64_t elements_blockscale_A = groupscale_m * blockscale_k; - int64_t elements_blockscale_B = groupscale_n * blockscale_k; + int64_t elements_blockscale_A = size(filter_zeros(group_layout_SFA)); + int64_t elements_blockscale_B = size(filter_zeros(group_layout_SFB)); total_elements_A += elements_A; total_elements_B += elements_B; @@ -402,6 +369,8 @@ void allocate(const OptionType &options) { stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + layout_SFA_host.push_back(group_layout_SFA); + layout_SFB_host.push_back(group_layout_SFB); } @@ -477,6 +446,12 @@ void initialize(const OptionType &options) { stride_D.reset(options.groups); stride_D.copy_from_host(stride_D_host.data()); + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + alpha_device.reset(options.groups); alpha_device.copy_from_host(ptr_alpha_host.data()); beta_device.reset(options.groups); @@ -500,14 +475,14 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha // Change device_id to another value if you are running on a machine with multiple GPUs and wish // to use a GPU other than that with device ID 0. int device_id = 0; - cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); + cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); GemmArguments arguments{ cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), host_problem_shapes_available ? 
options.problem_sizes_host.data() : (decltype(options.problem_sizes_host.data())) nullptr}, {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), - ptr_blockscale_A.get(), - ptr_blockscale_B.get() + ptr_blockscale_A.get(), layout_SFA.get(), + ptr_blockscale_B.get(), layout_SFB.get() }, { {}, // epilogue.thread @@ -577,12 +552,6 @@ bool verify(const OptionType &options) { // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape auto [m, n, k] = options.problem_sizes_host.at(group_idx); auto gemm_problem_shape = cute::make_shape(m, n, k); - auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{}))); - auto blockscale_m = cute::get<0>(blockscale_shape); - auto blockscale_n = cute::get<1>(blockscale_shape); - auto blockscale_k = cute::get<2>(blockscale_shape); - auto groupscale_m = blockscale_m * OptionType::GroupScaleConfig::ScaleMsPerTile; - auto groupscale_n = blockscale_n * OptionType::GroupScaleConfig::ScaleNsPerTile; // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx), @@ -610,32 +579,20 @@ bool verify(const OptionType &options) { ) ); - auto blockscale_A = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx), - cute::make_layout( - cute::make_shape(groupscale_m, blockscale_k, 1), - cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k) - ) - ); - auto blockscale_B = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx), - cute::make_layout( - cute::make_shape(groupscale_n, blockscale_k, 1), - cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k) - ) - ); + auto SFA = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx), + layout_SFA_host.at(group_idx)); + auto SFB = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx), + layout_SFB_host.at(group_idx)); using unused_t = decltype(D); - cutlass::reference::host::GettMainloopParams< + cutlass::reference::host::GettBlockScalingMainloopParams< ElementAccumulator, - decltype(A), + decltype(A), + decltype(SFA), decltype(B), - decltype(blockscale_A), - decltype(blockscale_B), - TileShape_ - > mainloop_params{ - A, B, // Operand Tensors - blockscale_A, blockscale_B // Groupwise scaling Tensors - }; + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; cutlass::reference::host::GettEpilogueParams< ElementScalar, @@ -647,8 +604,7 @@ bool verify(const OptionType &options) { unused_t, // bias unused_t, // Aux unused_t, // valpha - unused_t, // vbeta - ActivationFunctor + unused_t // vbeta > epilogue_params; epilogue_params.C = C; @@ -679,15 +635,9 @@ bool verify(const OptionType &options) { } /// Execute a given example GEMM computation -template +template int run(OptionType &options, bool host_problem_shapes_available = true) { - using TileShape = typename OptionType::GroupScaleConfig::TileShape; - const int ScaleGranularityM = OptionType::GroupScaleConfig::ScaleGranularityM; - const int ScaleGranularityN = OptionType::GroupScaleConfig::ScaleGranularityN; - const int ScaleMsPerTile = OptionType::GroupScaleConfig::ScaleMsPerTile; - const int ScaleNsPerTile = OptionType::GroupScaleConfig::ScaleNsPerTile; - allocate(options); initialize(options); @@ -797,18 +747,12 @@ int main(int argc, char const **args) { // Parse options // - Options options_1d1d; - Options options_1d2d; - Options options_2d1d; - Options options_2d2d; + Options 
options; - options_1d1d.parse(argc, args); - options_1d2d.parse(argc, args); - options_2d1d.parse(argc, args); - options_2d2d.parse(argc, args); + options.parse(argc, args); - if (options_1d1d.help) { - options_1d1d.print_usage(std::cout) << std::endl; + if (options.help) { + options.print_usage(std::cout) << std::endl; return 0; } @@ -816,22 +760,10 @@ int main(int argc, char const **args) { // Evaluate CUTLASS kernels // - auto run_tests = [&] (bool host_problem_shapes_available = true) { - std::cout << "Grouped GEMM kernel with 1D1D group scale" << std::endl; - run(options_1d1d, host_problem_shapes_available); - std::cout << "Grouped GEMM kernel with 1D2D group scale" << std::endl; - run(options_1d2d, host_problem_shapes_available); - std::cout << "Grouped GEMM kernel with 2D1D group scale" << std::endl; - run(options_2d1d, host_problem_shapes_available); - std::cout << "Grouped GEMM kernel with 2D2D group scale" << std::endl; - run(options_2d2d, host_problem_shapes_available); - std::cout << std::endl; - }; - std::cout << "Running tests with host problem shapes:" << std::endl; - run_tests(true); + run(options, true); std::cout << "Running tests without host problem shapes:" << std::endl; - run_tests(false); + run(options, false); #endif diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu new file mode 100644 index 0000000000..2ea42bbf58 --- /dev/null +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu @@ -0,0 +1,781 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grouped scale Hopper FP8 Grouped GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + This example demonstrates a grouped scaled FP8 Grouped GEMM using the new CUTLASS 3.0. + APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: + 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) + which are more efficient than the Ampere tensor core instructions. + 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large + blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous + copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA + descriptors to move between groups/problem_count (represented by groups). + 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). + 4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the + CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can + improve performance. + 5. This example is tuned specifically for the sparse groups case, where the number of active groups (groups + with non-zero problem count) is much smaller than the total number of groups. 
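Point 5 above is the distinguishing feature of this variant: most groups may be empty, and only groups with a non-zero problem size contribute any work. The sketch below uses hypothetical types and values (not taken from the example) to illustrate the "active groups" bookkeeping, matching the `M > 0` checks the example applies when accounting bytes and FLOPs.

```cpp
// Hypothetical sketch: iterate only "active" (non-empty) groups.
#include <cstdio>
#include <vector>

struct ProblemSize { int m, n, k; };

int main() {
  // Mostly-empty group list: only two of the groups actually have work.
  std::vector<ProblemSize> problems(16, ProblemSize{0, 0, 0});
  problems[3]  = {256, 512, 128};
  problems[11] = {512, 256, 128};

  int active = 0;
  for (std::size_t g = 0; g < problems.size(); ++g) {
    if (problems[g].m > 0) {  // skip empty groups, as the example does when counting bytes/FLOPs
      ++active;
      std::printf("group %zu: %dx%dx%d\n", g, problems[g].m, problems[g].n, problems[g].k);
    }
  }
  std::printf("%d of %zu groups are active\n", active, problems.size());
}
```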
+ Examples: + $ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups \ + --m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \ + --raster=h --swizzle=2 --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" + +// Includes from examples directory +#include "helper.h" +#include "hopper_fp8_commandline.hpp" + +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementBlockScale = float; // Element type for blockscaling during accumulation +using ElementCompute = float; // Element type for epilogue computation + +using ArchTag 
= cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + +using TileShape = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()... +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + +static constexpr int ScaleGranularityM = 1; +static constexpr int ScaleGranularityN = 128; +static constexpr int ScaleGranularityK = 128; +static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; +static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + +using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + + +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +using FusionOperation = cutlass::epilogue::fusion::LinearCombination; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule, + FusionOperation +>::CollectiveOp; + +using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule +>::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopWithGroupWiseScaling, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. 
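The `FP8BlockScaledAccum` schedule selected above accumulates each 128-deep K block and then folds in the per-block scale factors of A and B before adding to the running sum, which is also how the host-side `Gett` reference (removed further down in this diff in favor of the shared `gett.hpp`) computes its result. Below is a scalar sketch of that accumulation order under those assumptions; the names and the tail-block handling are illustrative, not the CUTLASS implementation.

```cpp
// Scalar sketch of blockwise-scaled accumulation (illustrative only):
// partial products are accumulated per 128-wide K block, then scaled by
// scale_a[k_block] * scale_b[k_block] before being added to the accumulator.
#include <cstdio>
#include <vector>

float blockscaled_dot(const std::vector<float>& a_row,   // K values of A (already converted from FP8)
                      const std::vector<float>& b_col,   // K values of B
                      const std::vector<float>& scale_a, // one scale per K block for this row group
                      const std::vector<float>& scale_b, // one scale per K block for this column group
                      int block_k = 128) {
  float acc = 0.f, partial = 0.f;
  const int K = static_cast<int>(a_row.size());
  for (int k = 0; k < K; ++k) {
    partial += a_row[k] * b_col[k];
    if ((k + 1) % block_k == 0 || k + 1 == K) {  // K-block boundary: apply the block scales
      int kb = k / block_k;
      acc += partial * scale_a[kb] * scale_b[kb];
      partial = 0.f;
    }
  }
  return acc;
}

int main() {
  std::vector<float> a(256, 1.f), b(256, 1.f), sa(2, 0.5f), sb(2, 0.5f);
  std::printf("blockscaled dot = %f (expected 64)\n", blockscaled_dot(a, b, sa, sb));
}
```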
+using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; +using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + +/// Initialization + +cutlass::DeviceAllocation problem_sizes; + +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; +std::vector offset_blockscale_A; +std::vector offset_blockscale_B; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; + +std::vector alpha_host; +std::vector beta_host; + +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation blockscale_block_A; +cutlass::DeviceAllocation blockscale_block_B; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_ref_D; +cutlass::DeviceAllocation ptr_blockscale_A; +cutlass::DeviceAllocation ptr_blockscale_B; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; + +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams>::RasterOrderOptions; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + double gbps; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + double gbps = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), gbps(gbps), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023, + ScopeMin scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) { + + double _scope_max, _scope_min; + int bits_input = cutlass::sizeof_bits::value; + if (bits_input == 1) { + _scope_max = 2; + _scope_min = 0; + } else if 
(bits_input <= 8) { + _scope_max = 2; + _scope_min = -2; + } else if (bits_input == 16) { + _scope_max = 5; + _scope_min = -5; + } else { + _scope_max = 8; + _scope_min = -8; + } + if constexpr (!std::is_same_v) { + _scope_max = scope_max; + } + if constexpr (!std::is_same_v) { + _scope_min = scope_min; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) _scope_max, (Element) _scope_min, 0); + + return true; +} + +/// Allocates device-side data +template +void allocate(const OptionType &options) { + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + int64_t total_elements_blockscale_A = 0; + int64_t total_elements_blockscale_B = 0; + + offset_A.clear(); + offset_B.clear(); + offset_C.clear(); + offset_D.clear(); + offset_blockscale_A.clear(); + offset_blockscale_B.clear(); + stride_A_host.clear(); + stride_B_host.clear(); + stride_C_host.clear(); + stride_D_host.clear(); + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_after_alignment_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto group_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + auto group_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_blockscale_A.push_back(total_elements_blockscale_A); + offset_blockscale_B.push_back(total_elements_blockscale_B); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + int64_t elements_blockscale_A = size(filter_zeros(group_layout_SFA)); + int64_t elements_blockscale_B = size(filter_zeros(group_layout_SFB)); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_elements_blockscale_A += elements_blockscale_A; + total_elements_blockscale_B += elements_blockscale_B; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + layout_SFA_host.push_back(group_layout_SFA); + layout_SFB_host.push_back(group_layout_SFB); + + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); + blockscale_block_A.reset(total_elements_blockscale_A); + blockscale_block_B.reset(total_elements_blockscale_B); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +template +void initialize(const OptionType &options) { + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_after_alignment_host.data()); + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + std::vector ptr_blockscale_A_host(options.groups); + std::vector 
ptr_blockscale_B_host(options.groups); + + alpha_host.clear(); + beta_host.clear(); + + for (int i = 0; i < options.groups; i++) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i); + ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i); + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_blockscale_A.reset(options.groups); + ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data()); + + ptr_blockscale_B.reset(options.groups); + ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_block(block_A, seed + 2022); + initialize_block(block_B, seed + 2023); + initialize_block(block_C, seed + 2024); + initialize_block(blockscale_block_A, seed + 2025, -1, 1); + initialize_block(blockscale_block_B, seed + 2026, -1, 1); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true) +{ + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + int device_id = 0; + cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); + + GemmArguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), host_problem_shapes_available ? 
options.problem_sizes_after_alignment_host.data() : (decltype(options.problem_sizes_after_alignment_host.data())) nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_blockscale_A.get(), layout_SFA.get(), + ptr_blockscale_B.get(), layout_SFB.get() + }, + { + {}, // epilogue.thread + ptr_C.get(), stride_C.get(), + ptr_D.get(), stride_D.get() + }, + kernel_hw_info + }; + + auto &fusion_args = arguments.epilogue.thread; + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + + arguments.scheduler.raster_order = options.raster; + // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) + arguments.scheduler.max_swizzle_size = options.swizzle; + + return arguments; +} + +template +bool verify(const OptionType &options) { + + // + // Compute reference output + // + + std::vector block_A_host(block_A.size()); + std::vector block_B_host(block_B.size()); + std::vector block_C_host(block_C.size()); + std::vector block_D_host_kernel(block_D.size()); + std::vector block_D_host_ref(block_D.size()); + std::vector blockscale_block_A_host(blockscale_block_A.size()); + std::vector blockscale_block_B_host(blockscale_block_B.size()); + + block_A.copy_to_host(block_A_host.data()); + block_B.copy_to_host(block_B_host.data()); + block_C.copy_to_host(block_C_host.data()); + block_D.copy_to_host(block_D_host_kernel.data()); + blockscale_block_A.copy_to_host(blockscale_block_A_host.data()); + blockscale_block_B.copy_to_host(blockscale_block_B_host.data()); + + bool passed = true; + for (int group_idx = 0; group_idx < options.groups; group_idx++) { + // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape + auto [m, n, k] = options.problem_sizes_after_alignment_host.at(group_idx); + auto gemm_problem_shape = cute::make_shape(m, n, k); + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx), + cute::make_layout( + cute::make_shape(m, k, 1), + stride_A_host.at(group_idx) + ) + ); + auto B = cute::make_tensor(block_B_host.data() + offset_B.at(group_idx), + cute::make_layout( + cute::make_shape(n, k, 1), + stride_B_host.at(group_idx) + ) + ); + auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx), + cute::make_layout( + cute::make_shape(m, n, 1), + stride_C_host.at(group_idx) + ) + ); + auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx), + cute::make_layout( + cute::make_shape(m, n, 1), + stride_D_host.at(group_idx) + ) + ); + + auto 
SFA = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx), + layout_SFA_host.at(group_idx)); + auto SFB = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx), + layout_SFB_host.at(group_idx)); + + using unused_t = decltype(D); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha_host.at(group_idx); + epilogue_params.beta = beta_host.at(group_idx); + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + auto this_group_passed = std::equal( + // std::execution::par_unseq, + block_D_host_ref.data() + offset_D.at(group_idx), + block_D_host_ref.data() + offset_D.at(group_idx) + m * n, + block_D_host_kernel.data() + offset_D.at(group_idx) + ); + + passed &= this_group_passed; + +#if 0 + std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl; +#endif + + } + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(OptionType &options, bool host_problem_shapes_available = true) +{ + + allocate(options); + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
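The profiling block below averages the timed loop over `options.iterations` and reports throughput via `options.gflops(...)` and the new `options.gbps(...)` added to hopper_fp8_commandline.hpp later in this diff. As a rough standalone sketch, grouped-GEMM throughput accounting looks like the following; the helper names are hypothetical, and the block-scale traffic that the real `gbps()` also counts is omitted here for brevity.

```cpp
// Hypothetical sketch of grouped-GEMM performance accounting:
// 2*M*N*K FLOPs per active group, and A/B/C reads plus D writes for bandwidth.
#include <cstdint>
#include <vector>

struct ProblemSize { int64_t m, n, k; };

double grouped_gflops(const std::vector<ProblemSize>& problems, double runtime_s) {
  double flops = 0.0;
  for (const auto& p : problems) {
    if (p.m > 0) flops += 2.0 * double(p.m) * double(p.n) * double(p.k);  // one multiply-add = 2 FLOPs
  }
  return flops / 1.0e9 / runtime_s;
}

double grouped_gbps(const std::vector<ProblemSize>& problems, double runtime_s,
                    int bytes_ab = 1 /*FP8*/, int bytes_cd = 1 /*FP8*/) {
  double bytes = 0.0;
  for (const auto& p : problems) {
    if (p.m == 0) continue;                              // only active groups move data
    bytes += double(p.m * p.k + p.k * p.n) * bytes_ab;   // read A and B
    bytes += double(p.m * p.n) * bytes_cd;               // read C (beta path)
    bytes += double(p.m * p.n) * bytes_cd;               // write D
  }
  return bytes / 1.0e9 / runtime_s;
}
```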
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + result.gbps = options.template gbps(result.avg_runtime_ms / 1000.0); + + std::string raster = "Heuristic"; + + if (options.raster == RasterOrderOptions::AlongN) { + raster = "Along N"; + } + else if (options.raster == RasterOrderOptions::AlongM) { + raster = "Along M"; + } + + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; + std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; + std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; + std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + + run(options, true); + + std::cout << "Running tests without host problem shapes:" << std::endl; + run(options, false); + +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt index f88b31674d..09d506dee1 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt @@ -59,3 +59,26 @@ cutlass_example_add_executable( TEST_SMALL TEST_SMALL_LARGE_GROUP ) + +# MSVC will fail to compile this example with the following error: +# fatal error C1083: Cannot open source file: : No such file or directory [...\examples\68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.vcxproj] +# This is a known issue and we are working on a fix. 
+if (NOT MSVC) + +cutlass_example_add_executable( + 68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups + 68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + ) + +endif() diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp index 3e425fe23e..19497176db 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -30,12 +30,11 @@ **************************************************************************************************/ // Command line options parsing -template +template struct Options { using RasterOrderOptions = _RasterOrderOptions; using ProblemShape = _ProblemShape; - using GroupScaleConfig = _GroupScaleConfig; bool help = false; @@ -43,6 +42,7 @@ struct Options { int iterations = 1000; int m = 1024, n = 512, k = 1024, groups = 10; std::string benchmark_path; + std::vector problem_sizes_after_alignment_host; std::vector problem_sizes_host; int const tma_alignment_bits = 128; int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; @@ -89,6 +89,7 @@ struct Options { // Decide how to initialize the problems if (!benchmark_path.empty()) { if (!benchmark_problems()) { + problem_sizes_after_alignment_host.clear(); problem_sizes_host.clear(); return; } @@ -105,8 +106,8 @@ struct Options { cmd.get_cmd_line_argument("n", cmd_line_n); cmd.get_cmd_line_argument("k", cmd_line_k); + problem_sizes_after_alignment_host.reserve(groups); problem_sizes_host.reserve(groups); - for (int i = groups; i > 0; i--) { int m = cmd_line_m; int n = cmd_line_n; @@ -120,6 +121,7 @@ struct Options { if (k < 1) { k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1); } + problem_sizes_after_alignment_host.push_back({m, n, k}); problem_sizes_host.push_back({m, n, k}); } } @@ -142,7 +144,7 @@ struct Options { break; } - cutlass::gemm::GemmCoord extent; + cutlass::gemm::GemmCoord extent_after_alignment, extent; std::vector tokens; cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); @@ -150,23 +152,81 @@ struct Options { for (int i = 0; i < int(tokens.size()); ++i) { int x = std::atoi(tokens.at(i).c_str()); + extent.at(i) = x; // round up if (x % alignment) { x += (alignment - (x % alignment)); } - extent.at(i) = x; + extent_after_alignment.at(i) = x; } - if (extent.product()) { - problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); - } + problem_sizes_after_alignment_host.push_back({extent_after_alignment.m(), extent_after_alignment.n(), extent_after_alignment.k()}); + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); } - groups = static_cast(problem_sizes_host.size()); + groups = static_cast(problem_sizes_after_alignment_host.size()); return true; } + /// Calculate memory bandwidth statistics + template + auto gbps(double runtime_s) const { + double total_read_bytes = 0; + double total_write_bytes = 0; + + // Calculate bytes read and written for each problem + for (int i = 0; i < groups; ++i) { + auto problem = 
problem_sizes_host.at(i); + auto M = cute::get<0>(problem); + auto N = cute::get<1>(problem); + auto K = cute::get<2>(problem); + + if (M > 0) { // Only count active problems + // Matrix A: M*K elements read + total_read_bytes += M * K * sizeof(ElementA); + + // Matrix B: K*N elements read + total_read_bytes += K * N * sizeof(ElementB); + + // Matrix C: M*N elements read (for beta operation) + total_read_bytes += M * N * sizeof(ElementC); + + // Block scales for A and B + auto blockscale_shape = cute::shape(cute::get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{}))); + auto blockscale_m = cute::get<0>(blockscale_shape); + auto blockscale_n = cute::get<1>(blockscale_shape); + auto blockscale_k = cute::get<2>(blockscale_shape); + auto groupscale_m = blockscale_m * ScaleMsPerTile; + auto groupscale_n = blockscale_n * ScaleNsPerTile; + + total_read_bytes += groupscale_m * blockscale_k * sizeof(ElementBlockScale); // A scales + total_read_bytes += groupscale_n * blockscale_k * sizeof(ElementBlockScale); // B scales + + // Matrix D: M*N elements written + total_write_bytes += M * N * sizeof(ElementD); + } + } + + return (total_read_bytes + total_write_bytes) / 1.0e9 / runtime_s; + } + + double bandwidth_util(double eff_bandwidth) const { + int memoryClockRate; + int memoryBusWidth; + cudaDeviceGetAttribute(&memoryClockRate, cudaDevAttrMemoryClockRate, 0); + cudaDeviceGetAttribute(&memoryBusWidth, cudaDevAttrGlobalMemoryBusWidth , 0); + double bw = 2.0 * memoryClockRate * (memoryBusWidth / 8) / 1.0e6; + return eff_bandwidth / bw * 100.0; + } + /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h deleted file mode 100644 index 1a94af670b..0000000000 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h +++ /dev/null @@ -1,520 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GETT in host-side code. -*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/gemm/gemm.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/relatively_equal.h" -#include -#include "cute/tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::reference::host { - -template -struct ElementTraits { - using type = T; -}; - -template -struct ElementTraits().get()), void> > > { - using type = decltype(std::declval().get()); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - class ElementAccumulator_, - class TensorA_, // (M, K, L) - class TensorB_, // (N, K, L) - class TensorScaleA_, // (m, k, L) - class TensorScaleB_, // (n, k, L) - class TileShape_ -> -struct GettMainloopParams { - using ElementAccumulator = ElementAccumulator_; - using TensorA = TensorA_; - using TensorB = TensorB_; - using EngineA = typename TensorA::engine_type; - using LayoutA = typename TensorA::layout_type; - using EngineB = typename TensorB::engine_type; - using LayoutB = typename TensorB::layout_type; - - using TensorScaleA = TensorScaleA_; - using TensorScaleB = TensorScaleB_; - using TileShape = TileShape_; - using EngineScaleA = typename TensorScaleA::engine_type; - using EngineScaleB = typename TensorScaleB::engine_type; - - TensorA A{}; - TensorB B{}; - TensorScaleA ScaleA{}; - TensorScaleB ScaleB{}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template< - class ElementScalar_, - class ElementScalingFactor_, - class ElementAccumulator_, - class ElementCompute_, - class TensorC_, // (M, N, L) - class TensorD_, // (M, N, L) - class VectorBias_ = TensorD_, // (M, 1) - class TensorAux_ = TensorD_, // (M, N, L) - class VectorAlpha_ = TensorD_, // (M, 1) - class VectorBeta_ = VectorAlpha_, // (M, 1) - class ActivationFunctor_ = cutlass::epilogue::thread::Identity, - class BiasBinaryOp_ = cutlass::plus, - bool PerColumnBias_ = false -> -struct GettEpilogueParams { - using ElementScalar = ElementScalar_; - using ElementScalingFactor = ElementScalingFactor_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using TensorC = TensorC_; - using TensorD = TensorD_; - using TensorAux = TensorAux_; - using VectorBias = VectorBias_; - using VectorAlpha = VectorAlpha_; - using VectorBeta = VectorBeta_; - using ActivationFunctor = ActivationFunctor_; - using BiasBinaryOp = BiasBinaryOp_; - - using EngineC = typename TensorC::engine_type; - using LayoutC = typename TensorC::layout_type; - using EngineD = typename 
TensorD::engine_type; - using LayoutD = typename TensorD::layout_type; - static constexpr bool PerColumnBias = PerColumnBias_; - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - - TensorC C{}; - TensorD D{}; - VectorBias Bias{}; - TensorAux Aux{}; - VectorAlpha Valpha{}; - VectorBeta Vbeta{}; - ElementCompute st = ElementCompute(1); - - ElementAccumulator* abs_max_D = nullptr; - ElementAccumulator* abs_max_Aux = nullptr; - - ElementScalingFactor scale_a = ElementScalingFactor(1); - ElementScalingFactor scale_b = ElementScalingFactor(1); - ElementScalingFactor scale_c = ElementScalingFactor(1); - ElementScalingFactor scale_d = ElementScalingFactor(1); - ElementScalingFactor scale_aux = ElementScalingFactor(1); - - bool beta_per_channel_scaling = false; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling -template < - class MainloopParams, - class EpilogueParams -> -void Gett( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - - static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{}); - static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{}); - // printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n"); - // printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n"); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { - for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { - for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { - typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; - gett_mainloop(mainloop_params, m, n, l, acc); - gett_epilogue(epilogue_params, m, n, l, acc); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Mainloop -template -void gett_mainloop( - MainloopParams const& mainloop_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); - static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementA = typename ElementTraits::type; - using ElementB = typename ElementTraits::type; - using ElementBlockScaleA = typename ElementTraits::type; - using ElementBlockScaleB = typename ElementTraits::type; - - using RingOp = multiply_add; - RingOp fma_op; - - multiplies scale_op; - - static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});; - - // Tempo accumulators to seperate blockwise accumulation - typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN]; - - // Zero out accumulators - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - - const int M = cute::size<0>(mainloop_params.A.layout()); - const int N = cute::size<0>(mainloop_params.B.layout()); - - const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA.layout()); - const int ScaleGranularityN = N / 
cute::size<0>(mainloop_params.ScaleB.layout()); - - assert(ScaleGranularityM && M % ScaleGranularityM == 0 && "ScaleGranularityM must divide M"); - assert(ScaleGranularityN && N % ScaleGranularityN == 0 && "ScaleGranularityN must divide N"); - - cute::Tensor blockscale_A = domain_offset(make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l)); - cute::Tensor blockscale_B = domain_offset(make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l)); - - // Compute on this k-block - for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { - - // Load Blockwise scaling factor from blockscale Tensors for B - int64_t block_k = k / kBlockK; - cute::Tensor scale_a = blockscale_A(_, block_k); - cute::Tensor scale_b = blockscale_B(_, block_k); - - // Load A - ElementAccumulator a_frag[kBlockM]; - for (int m_b = 0; m_b < kBlockM; ++m_b) { - if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); - } else { - a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // Load B - ElementAccumulator b_frag[kBlockN]; - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); - } else { - b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // do compute - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]); - } - } - - // Apply Groupwise-scaling at kBlockK boundary - // (a) Apply group and block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary - // (b) Zero-out partial temporary (acc_temp), - // (c) Update permanent (accu) - if ((k+1) % kBlockK == 0) { - for (int m_b = 0; m_b < kBlockM; ++m_b) { - auto scale_a_m_b = scale_a[m_b / ScaleGranularityM]; - for (int n_b = 0; n_b < kBlockN; ++n_b) { - auto scale_b_n_b = scale_b[n_b / ScaleGranularityN]; - ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b; - acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b]; - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - } - - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Epilogue -template -void gett_epilogue( - EpilogueParams const& epilogue_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); - static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementCompute = typename EpilogueParams::ElementCompute; - using ElementC = typename EpilogueParams::TensorC::value_type; - using ElementD = typename EpilogueParams::TensorD::value_type; - using ElementAux = typename EpilogueParams::TensorAux::value_type; - using ElementBias = typename EpilogueParams::VectorBias::value_type; - using ElementScalar = typename EpilogueParams::ElementScalar; - using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; - using ActivationFunctor = typename 
EpilogueParams::ActivationFunctor; - using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; - - constexpr bool PerColBias = EpilogueParams::PerColumnBias; - constexpr bool IsScalingAndAmaxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsScalingAndAmaxAuxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsReLUAuxNeeded = - (cute::is_same_v> or - cute::is_same_v>) and - cute::is_same_v; - constexpr bool IsClamp = - cute::is_same_v>; - - constexpr bool IsBackpropFusion = - cute::is_same_v> or - cute::is_same_v>; - - // Input related converter - NumericConverter accumulator_converter; - NumericConverter source_converter; - NumericConverter bias_converter; - [[maybe_unused]] NumericConverter aux_source_converter; - - // Scale related converter - NumericConverter scale_converter; - NumericConverter scaling_factor_converter; - - // Abs max converter - [[maybe_unused]] NumericConverter abs_max_output_converter; - - // Output related converter - NumericConverter destination_converter; - [[maybe_unused]] NumericConverter aux_destination_converter; - NumericConverter dBias_converter; - - // Epilogue operations - multiply_add epilogue_fma; - multiplies mul; - plus add; - - // Activation operation - - auto activation = [] (ElementCompute x, ElementCompute y = ElementCompute(0)) { - if constexpr (std::is_same_v) { - return x + y; - } else { - return ActivationFunctor()(x, y); - } - }; - - // Bias binary operation - BiasBinaryOp bias_op; - - // Do conversion - ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); - ElementCompute converted_beta = scale_converter(epilogue_params.beta); - ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); - ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); - ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); - ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); - ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); - - // Init local var - [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); - [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); - - converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); - converted_beta = mul(converted_beta, converted_scale_c); - - ElementCompute inter_accum[kBlockM][kBlockN]; - - for (int m_b = 0; m_b < kBlockM; ++m_b) { - ElementCompute local_dBias = ElementCompute(0); - - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - // Convert every type to ElementCompute first, do compute, convert to output type, write it out - ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); - // per-row alpha - if (raw_pointer_cast(epilogue_params.Valpha.data())) { - converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b)); - } - ElementCompute output = mul(converted_alpha, converted_acc); - - if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { - ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? 
n + n_b : m + m_b)); - output = bias_op(output, converted_bias); - } - - if (raw_pointer_cast(epilogue_params.C.data())) { - ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); - // per-row beta - if (epilogue_params.Vbeta.data()) { - converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b)); - } - output = epilogue_fma(converted_beta, converted_src, output); - } - - if constexpr (IsBackpropFusion) { - ElementAux aux_input = ElementAux(0); - if (raw_pointer_cast(epilogue_params.Aux.data())) { - aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); - } - - output = activation(output, aux_source_converter(aux_input)); - local_dBias = add(local_dBias, output); - } - else { - if (raw_pointer_cast(epilogue_params.Aux.data())) { - auto aux_output = output; - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); - aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); - } - - if constexpr (IsReLUAuxNeeded) { - epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0); - } else { - epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); - } - } - - if constexpr (IsClamp) { // Treat Clamp as ReLU - output = activation(output, {0, std::numeric_limits::max()}); - } - else { - output = activation(output); - } - } - - if constexpr (IsScalingAndAmaxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_output = amax_op(local_abs_max_output, output); - output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); - } - - inter_accum[m_b][n_b] = ElementCompute(output); - } - } // n_b - - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { - if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { - ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); - local_dBias = add(local_dBias, converted_dBias); - epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); - } - } - } // m_b - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); - } - } - } - -#if defined(_OPENMP) - #pragma omp critical(Abs_Max_Data_Update) -#endif - { - if constexpr (IsScalingAndAmaxOutputNeeded) { - if (epilogue_params.abs_max_D) { - *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); - } - } - - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - if (epilogue_params.abs_max_Aux) { - *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GEMM - General Matrix-Matrix contraction without conjugation options -template < - class MainloopParams, - class EpilogueParams -> -void Gemm3x( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - using namespace cute; - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); - static_assert(cute::rank(typename 
EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) " - "with Batchmode are supported"); - // Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett). - Gett(mainloop_params, epilogue_params); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // cutlass::reference::host - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu index c1978c3212..9b56697bdc 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu @@ -374,7 +374,7 @@ void allocate(Options const& options) { auto N = get<1>(problem); auto K = get<2>(problem); - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); offset_A.push_back(total_elements_A); offset_B.push_back(total_elements_B * cutlass::sizeof_bits::value / 8); @@ -510,7 +510,7 @@ void initialize(Options &options) { beta_device.copy_from_host(ptr_beta_host.data()); initialize_tensor(block_A, seed + 2023); - initialize_quant_tensor(block_B, seed + 2022); + initialize_tensor(block_B, seed + 2022); initialize_tensor(block_C, seed + 2021); initialize_scale(block_scale, options); initialize_zero(block_zero, options); @@ -519,13 +519,13 @@ void initialize(Options &options) { for (int32_t i = 0; i < options.groups; ++i) { - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{}); auto shape_scale = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), scale_k, Int<1>{}); auto layout_B = make_layout(shape_B, stride_B_host.at(i)); auto layout_scale = make_layout(shape_scale, stride_S_host_ref.at(i)); cudaStream_t stream = cudaStreamDefault; - cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.k, stream); + cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.c, stream); } problem_sizes.reset(options.groups); @@ -619,7 +619,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro arguments = Args { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, - {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k}, + {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c}, {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info }; @@ -676,6 +676,7 @@ bool verify(Options const& options) { for (int32_t i = 0; i < options.groups; ++i) { auto problem = options.problem_sizes_host.at(i); + // we don't swap and transpose in the verify so revert the problem shape. 
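+      // The kernel arguments built in args_from_options above pass B as the first operand and A as
+      // the second (operands swapped), so the device computes the transposed problem; the host
+      // reference below keeps the original orientation, hence N is read from position 0 and M from
+      // position 1 of the stored problem shape.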
auto N = get<0>(problem); auto M = get<1>(problem); auto K = get<2>(problem); @@ -712,7 +713,7 @@ bool verify(Options const& options) { CUDA_CHECK(cudaDeviceSynchronize()); passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor); - std::cout << "Group: " << i << " Status: " << passed << std::endl; + std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl; } } return passed; diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu index 07ff66b31a..8407cdad5e 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu @@ -341,7 +341,7 @@ void allocate(Options const& options) { auto N = get<1>(problem); auto K = get<2>(problem); - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); offset_A.push_back(total_elements_A); offset_B.push_back(total_elements_B * cutlass::sizeof_bits::value / 8); @@ -479,7 +479,7 @@ void initialize(Options& options) { beta_device.copy_from_host(ptr_beta_host.data()); initialize_tensor(block_A, seed + 2023); - initialize_quant_tensor(block_B, seed + 2022); + initialize_tensor(block_B, seed + 2022); cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size()); initialize_tensor(block_C, seed + 2021); initialize_scale(block_scale, options); @@ -565,7 +565,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro arguments = Args { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, - {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.k}, + {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.c}, {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info }; @@ -617,6 +617,7 @@ bool verify(Options const& options) { for (int32_t i = 0; i < options.groups; ++i) { auto problem = options.problem_sizes_host.at(i); + // we don't swap and transpose in the verify so revert the problem shape. 
auto N = get<0>(problem); auto M = get<1>(problem); auto K = get<2>(problem); @@ -630,11 +631,11 @@ bool verify(Options const& options) { stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1)); stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1)); - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i)); auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i)); cudaStream_t stream = cudaStreamDefault; - cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream); + cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream); // // Compute reference output @@ -659,7 +660,7 @@ bool verify(Options const& options) { CUDA_CHECK(cudaDeviceSynchronize()); passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor); - std::cout << "Group: " << i << " Status: " << passed << std::endl; + std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl; } } return passed; diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu index ffeb233ea5..41cccfbbf1 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu @@ -282,7 +282,7 @@ void allocate(Options const& options) { auto N = get<1>(problem); auto K = get<2>(problem); - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); offset_A.push_back(total_elements_A); offset_B.push_back(total_elements_B * cutlass::sizeof_bits::value / 8); @@ -418,7 +418,7 @@ void initialize(Options &options) { beta_device.copy_from_host(ptr_beta_host.data()); initialize_tensor(block_A, seed + 2023); - initialize_quant_tensor(block_B, seed + 2022); + initialize_tensor(block_B, seed + 2022); initialize_tensor(block_C, seed + 2021); initialize_scale(block_scale, options); initialize_zero(block_zero, options); @@ -485,7 +485,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro arguments = typename Gemm::Arguments { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, - {ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k}, + {ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c}, {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info }; @@ -542,6 +542,7 @@ bool verify(Options const& options) { for (int32_t i = 0; i < options.groups; ++i) { auto problem = options.problem_sizes_host.at(i); + // we don't swap and transpose in the verify so revert the problem shape. 
auto N = get<0>(problem); auto M = get<1>(problem); auto K = get<2>(problem); @@ -555,11 +556,11 @@ bool verify(Options const& options) { stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1)); stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1)); - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i)); auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i)); cudaStream_t stream = cudaStreamDefault; - cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream); + cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream); // // Compute reference output @@ -584,7 +585,7 @@ bool verify(Options const& options) { CUDA_CHECK(cudaDeviceSynchronize()); passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor); - std::cout << "Group: " << i << " Status: " << passed << std::endl; + std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl; } } return passed; diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt b/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt index 4c21cd4854..f32c5d527f 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt @@ -50,6 +50,7 @@ set(TEST_RANDOM_PERF_LARGE_GROUP --groups=100 --iterations=10) set(TEST_DIRECT_BATCHED --m=2048 --n=5120 --k=8192 --mode=0 --iterations=0) # Direct conversion set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --c=8192 --mode=1 --iterations=0) # Per Column scaling +set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --c=512 --mode=1 --iterations=0) # Group-wise scaling cutlass_example_add_executable( 69_hopper_mixed_dtype_grouped_gemm @@ -69,6 +70,7 @@ cutlass_example_add_executable( TEST_RANDOM_PERF_LARGE_GROUP TEST_DIRECT_BATCHED TEST_SCALE_PERCOL + TEST_SCALE_GROUP ) cutlass_example_add_executable( @@ -89,6 +91,7 @@ cutlass_example_add_executable( TEST_RANDOM_PERF_LARGE_GROUP TEST_DIRECT_BATCHED TEST_SCALE_PERCOL + TEST_SCALE_GROUP ) cutlass_example_add_executable( @@ -109,4 +112,5 @@ cutlass_example_add_executable( TEST_RANDOM_PERF_LARGE_GROUP TEST_DIRECT_BATCHED TEST_SCALE_PERCOL + TEST_SCALE_GROUP ) diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/README.md b/examples/69_hopper_mixed_dtype_grouped_gemm/README.md index f4d71ea3f1..10b57aa08c 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/README.md +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/README.md @@ -7,11 +7,11 @@ This example shows how to perform Grouped GEMMs on Hopper when A and B have diff - in the arguments, pass the group size, array of the problem sizes, and the array of strides for matrix A and B. - if scales and zero-points are included, also pass the array of their strides in the arguments. -Note that in Example 55, the argument `--g` is used to determine the block scale size. 
It is important not to confuse this with the `--groups` argument in this example, which specifies the number of GEMMs. +Note that in Example 55, the argument `--g` is used to determine the group size of scaling. To avoid confusion with the `--groups` argument in this example, which defines the number of GEMMs, `--c` is used here to represent the group size for scaling. ## Upcoming features -Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed. +Currently, the Mixed-input Grouped GEMM only supports row-wise scaling, and group-wise scaling for identical problem shapes across all groups. Please contact us if zero-points or block-wise scaling are needed. ## Copyright diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp b/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp index db391cce8f..8568b467dd 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp @@ -58,6 +58,7 @@ class GroupedMixedDtypeOptions : public MixedDtypeOptions { void parse(int argc, char const **args) { cutlass::CommandLine cmd(argc, args); cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("benchmark", benchmark_path); cmd.get_cmd_line_argument("c", c); MixedDtypeOptions::parse(argc, args); @@ -71,6 +72,7 @@ class GroupedMixedDtypeOptions : public MixedDtypeOptions { << " --m= Sets the M extent of the GEMM for all groups\n" << " --n= Sets the N extent of the GEMM for all groups\n" << " --k= Sets the K extent of the GEMM for all groups\n" + << " --c= Sets the chunk size for scaling the quantized weights\n" << " --groups= Sets the number of individual GEMM problems\n" << " --mode= The mode to run the gemm\n" << " --alpha= Epilogue scalar alpha\n" @@ -183,11 +185,6 @@ void grouped_mixed_dtype_profiling( result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); - - std::cout << " Problem Sizes, Alpha, Beta\n"; - for (int32_t i = 0; i < options.groups; ++i) { - std::cout << " " << options.problem_sizes_host[i] << ", " << alpha_host[i] << ", " << beta_host[i] << '\n'; - } std::cout << " Groups : " << options.groups << '\n' << " Avg runtime : " << result.avg_runtime_ms << " ms\n" << " GFLOPS : " << result.gflops << '\n'; diff --git a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu index 75d3437d1b..8be4f6395d 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu @@ -480,7 +480,12 @@ bool verify(const Options &options) { passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); - return passed; + block_SFD.sync_host(); + bool passed_sfd = cutlass::reference::host::TensorEquals(block_reference_SFD.host_view(), block_SFD.host_view()); + passed_sfd &= (cutlass::reference::host::TensorNorm(block_reference_SFD.host_view()) > 0); + passed_sfd &= (cutlass::reference::host::TensorNorm(block_SFD.host_view()) > 0); + + return passed && passed_sfd; } /// Execute a given example GEMM computation diff --git 
a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu index 1d1314d145..c879212223 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -67,9 +67,6 @@ --b=2048 --h=2048 --d=2048 --q=2048 --k=2048 */ -#define DSHOW(x) print(#x ": "); print(x); print("\n"); -#define DSHOWT(x) print(#x ": "); print_tensor(x); print("\n"); - #include #include #include @@ -247,8 +244,8 @@ struct Options { << " and are split B-ways, alternatingly +10% and -10%\n" << " with the last batch sized to make it fit\n" << " implies at least residual masking for correctness\n" - << " --sm-count Sets SM count rather than querying it\n" - << " --kernel-filter= Sets regexp to match kernel against\n" + << " --sm-count Sets SM count rather than querying it\n" + << " --kernel-filter= Sets regexp to match kernel against\n" << "\n"; return out; diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu new file mode 100644 index 0000000000..1c02a29ef0 --- /dev/null +++ b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu @@ -0,0 +1,865 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example implementation of fused multi-head attention for Blackwell using CUTLASS 3. + + This example showcases the use of CUTLASS to build backward fused + multi-head attantion (FMHA) collectives from existing CUTLASS collectives targeting + the NVIDIA Blackwell architecture. + + Background and motivation + ------------------------- + CUTLASS is a highly flexible library that provides open-source building blocks + for tensor core programming for GEMM or GEMM-like problems. 
Fused multi-head + attention (FMHA) is a foundational kernel for large language models (LLMs) since it + makes long sequence lengths feasible from a memory-usage perspective. It also + improves computational efficiency since it transforms an outer-product-like and + a matrix-vector-like GEMM into a fused operation with much higher arithmetic + intensity. For more details, see Dao et al, 2022; Dao, 2023. + Implementing this kernel in CUTLASS enabled easy customization and high + performance. + + Introduction + ------------ + The example targets the NVIDIA Blackwell architecture, and takes advantage of + 5th gen tensor cores and the Tensor Memory Accelerator (TMA), just like + GEMMs do. It provides a backward pass (often abbreviated + bwd in the code). + The code is structured into three layers: The runner (and the reference kernels) + takes care of initialization, measurement, and testing; the device layer + orchestrates kernel calls and partitions workspace; and the kernel layer (just + like the CUTLASS kernel layer. + + Support + --------- + + We support fp16 and fp8 data types with a head dimension of 128. + + Example usage: + $ ./examples/77_blackwell_fmha/77_blackwell_fmha_bwd_fp16 \ + --b=2048 --h=2048 --d=2048 --q=2048 --k=2048 +*/ + +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "reference/fmha_fwd_reference.hpp" +#include "reference/fmha_bwd_reference.hpp" +#include "reference/reference_abs_error.hpp" + +#include "collective/fmha_fusion.hpp" +#include "device/fmha_device_bwd.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; +using namespace cutlass::fmha::kernel; +using namespace cutlass::fmha::collective; +using namespace cutlass::fmha; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class InitStyle { + kOne, kZero, kLinearStride128, kLinearStride1, kRandom, kNone +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help = false; + bool error = false; + + int b = 16; + int h = 16; + int h_k = 1; + int q = 1024; + int k = 1024; + int d = 128; + int iterations = 3; + bool verify = false; + bool verbose = false; + + bool causal = false; + int sm_count = 0; + + std::string kernel_filter; + + InitStyle init_style_q = InitStyle::kRandom; + InitStyle init_style_k = InitStyle::kRandom; + InitStyle init_style_v = InitStyle::kRandom; + InitStyle init_style_do = InitStyle::kRandom; + bool skip_reference = false; + + static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) { + std::string s; + cmd.get_cmd_line_argument(name, s, s); + if (s.empty()) { + dst = src; + } + else { + if (s == "r") { + dst = InitStyle::kRandom; + } + else if (s == "0") { + dst = InitStyle::kZero; + } + else if (s == "1") { + dst = InitStyle::kOne; + } + else if (s == "d") { + dst = InitStyle::kLinearStride1; + } + else if (s == "s") { + dst = InitStyle::kLinearStride128; + } + else if (s == "n") { + dst = InitStyle::kNone; + } + else { + std::cout << "Error: " << s << " is not a valid input type.\n"; + std::exit(-1); + } + } + } + + // Parses the command line + 
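+  // Unless overridden on the command line, h defaults to 2048/d, b defaults to 16384/k, and q and k
+  // default to each other (1024 if neither is given); see parse() below. The init-style flags
+  // (--init-style, --init-style-q/-k/-v/-do) accept r (random), 0, 1, d (linear, stride 1),
+  // s (linear, stride 128), or n (no initialization), as handled in get_init_style_argument above.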
void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + Options defaults; + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("d", d, defaults.d); + cmd.get_cmd_line_argument("h", h, -1); + if (h == -1) h = 2048 / d; + + cmd.get_cmd_line_argument("q", q, -1); + cmd.get_cmd_line_argument("k", k, -1); + if (q == -1) q = k; + if (k == -1) k = q; + if (q == -1 && k == -1) q = k = defaults.q; + + cmd.get_cmd_line_argument("b", b, -1); + if (b == -1) b = 16384 / k; + if (b == 0) b = 1; + + cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); + verify = cmd.check_cmd_line_flag("verify"); + verbose = cmd.check_cmd_line_flag("verbose"); + std::string mask; + cmd.get_cmd_line_argument("mask", mask, ""); + if (mask == "causal") { + causal = true; + } + else { + causal = defaults.causal; + } + + skip_reference = cmd.check_cmd_line_flag("skip-reference"); + cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); + + get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q); + get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_k); + get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_v); + get_init_style_argument(cmd, "init-style", init_style_do, defaults.init_style_do); + get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q); + get_init_style_argument(cmd, "init-style-k", init_style_k, init_style_k); + get_init_style_argument(cmd, "init-style-v", init_style_v, init_style_v); + get_init_style_argument(cmd, "init-style-do", init_style_do, init_style_do); + + cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter); + } + + /// Prints the usage statement.
+ std::ostream & print_usage(std::ostream &out) const { + + out << "77_blackwell_fmha_bwd\n\n" + << " This example showcases the use of CUTLASS's collective operation builders to easily construct\n" + << " fused multi-head attention kernels for the backward pass targeting NVIDIA's Blackwell architecture.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --b= Sets the B extent\n" + << " --h= Sets the H extent\n" + << " --q= Sets the Q extent\n" + << " --k= Sets the K extent\n" + << " --d= Sets the D extentn" + << " --iterations= Benchmarking iterations\n" + << " --verify Verify results\n" + << " --verbose Print smem and execution time per kernel\n" + << " --mask= Enables masking\n" + << " --sm-count Sets SM count rather than querying it\n" + << " --kernel-filter= Sets regexp to match kernel against\n" + << "\n"; + + return out; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +void initialize_block( + DeviceAllocation& block, + uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) { + + switch (init_style) { + case InitStyle::kOne: { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) 1, (Element) 1); + break; + } + case InitStyle::kZero: { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) 0, (Element) 0); + break; + } + case InitStyle::kRandom: { + cutlass::reference::device::BlockFillRandomGaussian( + block.get(), block.size(), seed, (Element) 0, (Element) 1); + break; + } + case InitStyle::kLinearStride1: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 128; i ++) { + for (int j = 0; j < 128; j++) { + data[j + 128*i] = static_cast((double) (j % 4)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kLinearStride128: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 128; i ++) { + for (int j = 0; j < 128; j++) { + data[j + 128*i] = static_cast((double) (i % 4)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kNone: { + break; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ExampleResult { + bool passed = false; + bool verified = false; + float runtime_ms = 0; + double tflops_tc_s = 0; + size_t smem_size = 0; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class TileShape, + class DispatchPolicy, + class ActiveMask, + class... 
KernelOptions +> +struct BwdRunner { + +#ifdef FP8 + using Element = cutlass::float_e4m3_t; +#else + using Element = cutlass::half_t; +#endif + using ElementAccumulator = float; + + // Q K D (H B) + using ProblemShapeType = cute::tuple>; + + using Operation = cutlass::fmha::device::Sm100FmhaBwd; + + using TensorStride = Stride>; // Seq D (H B) + using StrideQ = TensorStride; + using StrideK = TensorStride; + using StrideV = TensorStride; + using StrideO = TensorStride; + using StrideLSE = Stride<_1, Stride>; // Seq (H B) + + // Backwards specific + using StrideDQ = TensorStride; + using StrideDK = TensorStride; + using StrideDV = TensorStride; + using StrideDO = TensorStride; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideLSE stride_LSE; + + StrideDQ stride_dQ; + StrideDK stride_dK; + StrideDV stride_dV; + StrideDO stride_dO; + + uint64_t seed = 0; + + DeviceAllocation block_Q; + DeviceAllocation block_K; + DeviceAllocation block_V; + DeviceAllocation block_O; + DeviceAllocation block_LSE; + + DeviceAllocation block_dQ; + DeviceAllocation block_dK; + DeviceAllocation block_dV; + DeviceAllocation block_dO; + + DeviceAllocation block_ref_dQ; + DeviceAllocation block_ref_dK; + DeviceAllocation block_ref_dV; + + // + // Methods + // + bool verify(const ProblemShapeType& problem_shape) { + auto [Q, K, D, HB] = problem_shape; + auto [H, B] = HB; + + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), + select<0,2,3>(problem_shape), + stride_Q); + + Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), + select<1,2,3>(problem_shape), + stride_K); + + Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), + select<1,2,3>(problem_shape), + stride_V); + + Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), + select<0,2,3>(problem_shape), + stride_O); + + Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), + select<0,3>(problem_shape), + stride_LSE); + + Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), + select<0,2,3>(problem_shape), + stride_dQ); + + Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), + select<1,2,3>(problem_shape), + stride_dK); + + Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), + select<1,2,3>(problem_shape), + stride_dV); + + Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), + select<0,2,3>(problem_shape), + stride_dO); + + fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{}); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-0 : 1e-2; + const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3; + + // Check if output from CUTLASS kernel and reference kernel are equal or not + double max_diff = 0; + double mean_diff = 0; + reference_abs_diff(block_dQ, block_ref_dQ, max_diff, mean_diff); + + bool passed_dQ = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (! passed_dQ) { + std::cerr << "failed dQ: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + reference_abs_diff(block_dK, block_ref_dK, max_diff, mean_diff); + + bool passed_dK = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (!
passed_dK) { + std::cerr << "failed dK: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + reference_abs_diff(block_dV, block_ref_dV, max_diff, mean_diff); + + bool passed_dV = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (! passed_dV) { + std::cerr << "failed dV: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + return passed_dQ && passed_dK && passed_dV; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_shape, Options const& options) { + auto [Q, K, D, HB] = problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + + auto shape_QO = select<0,2,3>(problem_shape); + auto shape_KV = select<1,2,3>(problem_shape); + auto shape_LSE = select<0,3>(problem_shape); + + stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); + stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H)); + stride_V = stride_K; + stride_O = stride_Q; + stride_LSE = make_stride(_1{}, make_stride(Q, Q*H)); + + stride_dQ = stride_Q; + stride_dK = stride_K; + stride_dV = stride_V; + stride_dO = stride_O; + + auto lsize = [](auto shape) { + return size(make_shape(1ull, shape)); + }; + + block_Q.reset(lsize(shape_QO)); + block_K.reset(lsize(shape_KV)); + block_V.reset(lsize(shape_KV)); + block_O.reset(lsize(shape_QO)); + block_LSE.reset(lsize(shape_LSE)); + + block_dQ.reset(lsize(shape_QO)); + block_dK.reset(lsize(shape_KV)); + block_dV.reset(lsize(shape_KV)); + block_dO.reset(lsize(shape_QO)); + + block_ref_dQ.reset(lsize(shape_QO)); + block_ref_dK.reset(lsize(shape_KV)); + block_ref_dV.reset(lsize(shape_KV)); + + initialize_block(block_Q, seed + 2023, options.init_style_q); + initialize_block(block_K, seed + 2022, options.init_style_k); + initialize_block(block_V, seed + 2021, options.init_style_v); + initialize_block(block_dO, seed + 2020, options.init_style_do); + + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), + select<0,2,3>(problem_shape), + stride_Q); + + Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), + select<1,2,3>(problem_shape), + stride_K); + + Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), + select<1,2,3>(problem_shape), + stride_V); + + Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), + select<0,2,3>(problem_shape), + stride_O); + + Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), + select<0,3>(problem_shape), + stride_LSE); + + if (! 
options.skip_reference) { + fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{}); + } + } + + ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b)); + + initialize(problem_shape, options); + + ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d); + + typename Operation::Arguments arguments{ + problem_shape, + block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + block_O.get(), stride_O, + block_LSE.get(), stride_LSE, + block_dO.get(), stride_dO, + block_dQ.get(), stride_dQ, + block_dK.get(), stride_dK, + block_dV.get(), stride_dV, + softmax_scale, + hw_info + }; + + Operation op; + + ExampleResult example_result; + + example_result.smem_size = Operation::Kernel::SharedStorageSize; + + size_t workspace_size = 0; + workspace_size = Operation::get_workspace_size(arguments); + DeviceAllocation workspace(workspace_size); + + cutlass::Status status = cutlass::Status::kSuccess; + status = op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + status = op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + // Run + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + } + + // Record an event at the start of a series of GEMMs + result = cudaEventRecord(events[0]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + for (int i = 0; i < options.iterations; i++) { + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMMs are complete + result = cudaEventRecord(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Wait for work on the device to complete. 
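+    // cudaEventSynchronize blocks the host until the stop event recorded above has completed, so the
+    // elapsed time measured next covers all options.iterations kernel launches; it is averaged by
+    // dividing by the iteration count below.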
+ result = cudaEventSynchronize(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + runtime_ms /= static_cast(options.iterations); + + double flops = 10.0 * (std::is_same_v ? 0.5 : 1.0); + flops *= static_cast(get<0>(problem_shape)); + flops *= static_cast(get<1>(problem_shape)); + flops *= static_cast(get<2>(problem_shape)); + flops *= static_cast(get<3,0>(problem_shape)); + flops *= static_cast(get<3,1>(problem_shape)); + double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); + example_result.tflops_tc_s = tflops_s; + example_result.runtime_ms = runtime_ms; + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Verify that the result is correct + bool passed = true; + if (options.verify) { + passed = verify(problem_shape); + if (passed) example_result.verified = true; + } + + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + return example_result; + } + + example_result.passed = true; + + return example_result; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to print a description of the example run and its result +void print_result(const std::string& description, ExampleResult result, bool verbose) { + std::ios fmt(nullptr); + fmt.copyfmt(std::cout); + std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] "); + std::cout << std::setw(32) << std::left << description; + std::cout.copyfmt(fmt); + std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl; + if (verbose) { + std::cout << " t=" << result.runtime_ms << "ms, " + "smem=" << result.smem_size << "b" << std::endl; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct KernelCoop {}; + +////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _64; + + run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma"); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... 
kernel_options) { + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _128; + + run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma"); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main_single(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) { + std::cout + << "This example requires a GPU of NVIDIA's Blackwell Architecture " + << "(compute capability 100a) and CUDA 12.8 or greater.\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + if (options.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + else { + hw_info.sm_count = options.sm_count; + } + + std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " "; + std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " "; + std::cout << "#SM " << hw_info.sm_count << std::endl; + + auto with_causal = [&](auto fn) { + if (options.causal) { + fn(CausalMask{}); + } + else { + fn(NoMask{}); + } + }; + + with_causal([&](auto fusion) { + if (options.d <= 64) { + run_bwd_64(fusion, options, hw_info); + } + else if (options.d <= 128) { + run_bwd_128(fusion, options, hw_info); + } + else { + std::cout << "No kernel instantiated for d=" << options.d << std::endl; + } + }); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + std::vector full_arguments(args, args + argc); + + int result = 0; + + bool recursed = false; + for (size_t i = 1; i < full_arguments.size(); i++) { + if (full_arguments[i].find(',') != std::string::npos) { + auto arg = full_arguments[i]; + size_t eq_pos = arg.find('='); + std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1); + std::string rest = eq_pos == std::string::npos ? 
arg : arg.substr(eq_pos+1); + for (;;) { + size_t comma_pos = rest.find(','); + std::string current = rest.substr(0, comma_pos); + full_arguments[i] = prefix + current; + std::vector next_args; + for (auto& elem : full_arguments) { next_args.push_back(elem.data()); } + main(argc, next_args.data()); + if (comma_pos == std::string::npos) break; + rest = rest.substr(comma_pos+1); + } + recursed = true; + break; + } + } + + if (! recursed) { + main_single(argc, args); + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/77_blackwell_mla.cu b/examples/77_blackwell_fmha/77_blackwell_mla.cu new file mode 100644 index 0000000000..baa70fce18 --- /dev/null +++ b/examples/77_blackwell_fmha/77_blackwell_mla.cu @@ -0,0 +1,832 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file A MLA (Multi-Head Latent Attention) inference kernel sample for the + NVIDIA Blackwell Architecture. 
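+
+    The sample loads the latent and rope portions of Q and of the cache C (optionally
+    paged) via TMA or cp.async, uses 2SM tensor core MMAs to handle the large latent
+    head dimension, and can split the KV dimension across CTAs (--split_kv), merging
+    the partial results with a separate reduction kernel.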
+*/ + +#include +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "reference/fmha_mla_reference.hpp" +#include "reference/reference_abs_error.hpp" + +#include "device/sm100_mla.hpp" +#include "kernel/sm100_mla_tile_scheduler.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; +using namespace cutlass::fmha::kernel; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class InitStyle { + kOne, kLinearStride128, kLinearStride1, kRandom, kNone +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help = false; + bool error = false; + + int b = 1; + int k = 256; + int split_kv = -1; // number of split along k dim. + bool is_var_split_kv = false; + int max_split_kv = 16; + int page = -1; + float spread = 0.2f; + int iterations = 3; + bool verify = false; + bool verbose = false; + + int sm_count = 0; + + std::string kernel_filter; + + InitStyle init_style_q = InitStyle::kRandom; + InitStyle init_style_c = InitStyle::kRandom; + + static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) { + std::string s; + cmd.get_cmd_line_argument(name, s, s); + if (s.empty()) { + dst = src; + } + else { + if (s == "r") { + dst = InitStyle::kRandom; + } + else if (s == "1") { + dst = InitStyle::kOne; + } + else if (s == "d") { + dst = InitStyle::kLinearStride1; + } + else if (s == "s") { + dst = InitStyle::kLinearStride128; + } + else if (s == "n") { + dst = InitStyle::kNone; + } + else { + std::cout << "Error: " << s << " is not a valid input type.\n"; + std::exit(-1); + } + } + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + Options defaults; + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("k", k, -1); + if (k == -1) k = defaults.k; + + cmd.get_cmd_line_argument("b", b, -1); + if (b == -1) b = 16384 / k; + if (b == 0) b = 1; + + cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv); + cmd.get_cmd_line_argument("page", page, defaults.page); + cmd.get_cmd_line_argument("spread", spread, defaults.spread); + cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false); + if (page == -1) { + is_var_split_kv = false; + } + cmd.get_cmd_line_argument("max_split_kv", max_split_kv, defaults.max_split_kv); + if (is_var_split_kv == true) { + split_kv = max_split_kv; + } + cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); + verify = cmd.check_cmd_line_flag("verify"); + verbose = cmd.check_cmd_line_flag("verbose"); + cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); + + get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q); + get_init_style_argument(cmd, "init-style", init_style_c, defaults.init_style_c); + get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q); + get_init_style_argument(cmd, "init-style-c", init_style_c, init_style_c); + + cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter); + } + + /// Prints the usage 
statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "77_blackwell_mla\n\n" + << " This example showcases the use of CUTLASS for fused multi-head latent\n" + << " attention kernels targeting NVIDIA's Blackwell architecture.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --b= Sets the B extent\n" + << " --k= Sets the K extent\n" + << " --page= Enables paging and sets the page size\n" + << " --iterations= Benchmarking iterations\n" + << " --spread= Relative spread away from K for paging\n" + << " --split_kv= Split KV factor\n" + << " --verify Verify results\n" + << " --verbose Print smem and execution time per kernel\n" + << " --sm-count Sets SM count rather than querying it\n" + << " --kernel-filter= Sets regexp to match kernel against\n" + << "\n"; + + return out; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +void initialize_block( + DeviceAllocation& block, + uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) { + + switch (init_style) { + case InitStyle::kOne: { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) 1, (Element) 1); + break; + } + case InitStyle::kRandom: { + cutlass::reference::device::BlockFillRandomGaussian( + block.get(), block.size(), seed, (Element) -1, (Element) 1); + break; + } + case InitStyle::kLinearStride1: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 128; i ++) { + for (int j = 0; j < 128; j++) { + data[j + 128*i] = static_cast((double) (j % 4)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kLinearStride128: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 64; i ++) { + for (int j = 0; j < 64; j++) { + data[j + 64*i] = static_cast((double) (i % 9)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kNone: { + break; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ExampleResult { + bool passed = false; + bool verified = false; + float runtime_ms = 0; + double tflops_tc_s = 0; + double tbytes_s = 0; + size_t smem_size = 0; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct IsPersistent { + static const bool value = v; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class TileShape, + class PersistenceOption = IsPersistent +> +struct Runner { + +#ifdef FP8 + using Element = cutlass::float_e4m3_t; +#elif FP16 + using Element = cutlass::half_t; +#else + #error "Must either define FP8 or FP16" +#endif + + using ElementAcc = float; + using ElementOut = cutlass::half_t; + + using TileShapeH = cute::tuple_element_t<0, TileShape>; + using TileShapeD = cute::tuple_element_t<2, TileShape>; + + // H K (D_latent D_rope) B + using ProblemShape = cute::tuple; + + using StrideQ = cute::tuple; // H D B + using StrideK = cute::tuple; // K D B + using StrideO = StrideK; // H D B + using StrideLSE = cute::tuple<_1, int>; // H B + + using TileScheduler = std::conditional_t< + PersistenceOption::value, + 
Sm100MlaPersistentTileScheduler, + Sm100MlaIndividualTileScheduler + >; + + using Kernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< + TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler + >; + using Operation = cutlass::fmha::device::MLA; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q_latent; + StrideK stride_C_latent; + StrideQ stride_Q_rope; + StrideK stride_K_rope; + StrideO stride_O; + StrideLSE stride_LSE; + StrideLSE stride_PT; + + uint64_t seed = 0; + + int page_size = -1; + int page_count = -1; + + // We allocate Q and C as first latent, then rope + // This means that we offset the pointer by HeadDim_latent to get the rope + // portion + DeviceAllocation block_Q; + DeviceAllocation block_C; + DeviceAllocation block_O; + DeviceAllocation block_seq; + DeviceAllocation block_PT; + DeviceAllocation block_split_kv; + DeviceAllocation block_accum_split_len; + DeviceAllocation block_LSE; + DeviceAllocation block_ref_O; + DeviceAllocation block_ref_LSE; + + ElementAcc scale; + + // + // Methods + // + + bool verify(const ProblemShape& problem_shape) { + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + int page_K = K; + int page_B = B; + if (block_PT.get() != nullptr) { + page_K = page_size; + page_B = page_count; + } + + Tensor mQ_latent = make_tensor(make_gmem_ptr(block_Q.get()), + cute::make_tuple(H, D_latent, B), + stride_Q_latent); + + Tensor mQ_rope = make_tensor(make_gmem_ptr(block_Q.get() + D_latent), + cute::make_tuple(H, D_rope, B), + stride_Q_rope); + + Tensor mC_latent = make_tensor(make_gmem_ptr(block_C.get()), + cute::make_tuple(page_K, D_latent, page_B), + stride_C_latent); + + Tensor mK_rope = make_tensor(make_gmem_ptr(block_C.get() + D_latent), + cute::make_tuple(page_K, D_rope, page_B), + stride_K_rope); + + Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()), + cute::make_tuple(H, D_latent, B), + stride_O); + + Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()), + cute::make_tuple(H, B), + stride_LSE); + + Tensor mSeq = make_tensor(make_gmem_ptr(static_cast(block_seq.get())), make_shape(B)); + Tensor mPT = make_tensor(make_gmem_ptr(static_cast(block_PT.get())), make_shape(ceil_div(K, page_size), B), stride_PT); + + fmha_mla_reference(problem_shape, mSeq, mPT, mQ_latent, mQ_rope, mC_latent, mK_rope, mO, mLSE, scale); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2; + const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3; + + // Check if output from CUTLASS kernel and reference kernel are equal or not + double max_diff = 0; + double mean_diff = 0; +#ifdef B2B + reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff); +#else + reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff); +#endif + + bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (! passed_O) { + std::cerr << "failed O: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + bool passed_LSE = true; +#ifndef B2B + reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff); + + passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if ( ! 
passed_LSE) { + std::cerr << "failed LSE: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } +#endif + + return passed_O && passed_LSE; + } + + ProblemShape initialize(const Options& options) { + auto problem_shape = cute::make_tuple(TileShapeH{}, options.k, TileShapeD{}, options.b); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + // the scale is based on the non-absorbed sizes, change as appropriate + // we can't determine this parameter from the info we have, it's an input + int D_non_latent = 128; + scale = static_cast(1.0 / sqrt(1.0 * (D_non_latent + D_rope))); + // Shape (H, D, B) + stride_Q_latent = cute::make_tuple(static_cast(0 + D_latent + D_rope), _1{}, static_cast(H * (0 + D_latent + D_rope))); + stride_Q_rope = stride_Q_latent; + stride_O = cute::make_tuple(static_cast(0 + D_latent), _1{}, static_cast(0 + H * D_latent)); + stride_LSE = cute::make_tuple(_1{}, 0 + H); + + block_Q.reset(static_cast(options.b) * H * (D_latent + D_rope)); + block_O.reset(static_cast(options.b) * H * D_latent); + block_LSE.reset(static_cast(options.b) * H); + block_ref_O.reset(static_cast(options.b) * H * D_latent); + block_ref_LSE.reset(static_cast(options.b) * H); + + if (options.page == -1) { + + stride_C_latent = cute::make_tuple(static_cast(0 + D_latent + D_rope), _1{}, static_cast(options.k) * (D_latent + D_rope)); + stride_K_rope = stride_C_latent; + + block_C.reset(static_cast(options.b) * options.k * (D_latent + D_rope)); + + } + else { + + float spread = options.spread; + int max_K = static_cast((1 + spread) * K); + int min_K = static_cast((1 - spread) * K); + page_size = options.page; + page_count = B * ceil_div(max_K, page_size); + stride_PT = cute::make_stride(_1{}, page_count); + + std::vector host_seq(B); + std::vector host_PT(page_count * B); + + for (int i = 0; i < B; i++) { + int seq = min_K + rand() % (max_K - min_K + 1); + host_seq[i] = seq; + for (int j = 0; j < ceil_div(seq, page_size); j++) { + host_PT[page_count * i + j] = i + j * B; + } + } + + block_seq.reset(host_seq.size()); + block_seq.copy_from_host(host_seq.data(), host_seq.size()); + block_PT.reset(host_PT.size()); + block_PT.copy_from_host(host_PT.data(), host_PT.size()); + + get<1>(problem_shape) = max_K; + + stride_C_latent = cute::make_tuple(static_cast(0 + D_latent + D_rope), _1{}, page_size * static_cast((D_latent + D_rope))); + stride_K_rope = stride_C_latent; + + block_C.reset(page_count * page_size * static_cast((D_latent + D_rope))); + + if (options.is_var_split_kv == true) { + std::vector host_split_kv(B); + for(int i = 0; i < B; ++i) { + auto len = host_seq[i]; + int split = ceil_div(options.max_split_kv, ceil_div(max_K, len)); + host_split_kv[i] = split; + } + block_split_kv.reset(B); + block_split_kv.copy_from_host(host_split_kv.data(), host_split_kv.size()); + } + } + + initialize_block(block_Q, seed + 2023, options.init_style_q); + initialize_block(block_C, seed + 2022, options.init_style_c); + + return problem_shape; + } + + ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + + ProblemShape problem_shape = initialize(options); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + typename Operation::Arguments arguments{ + problem_shape, + { scale, + block_Q.get(), stride_Q_latent, + block_Q.get() + D_latent, stride_Q_rope, + block_C.get(), stride_C_latent, + block_C.get() + D_latent, stride_K_rope, + block_seq.get(), + block_PT.get(), stride_PT, + page_count, page_size}, + { block_O.get(), + stride_O, + 
block_LSE.get(), + stride_LSE}, + hw_info, + options.split_kv, + options.is_var_split_kv ? block_split_kv.get() : nullptr + }; + if (options.split_kv < 0 && !options.is_var_split_kv) { + Operation::set_split_kv(arguments); + } + + Operation op; + + ExampleResult example_result; + + example_result.smem_size = Operation::Kernel::SharedStorageSize; + + size_t workspace_size = 0; + workspace_size = Operation::get_workspace_size(arguments); + DeviceAllocation workspace(workspace_size); + + cutlass::Status status = cutlass::Status::kSuccess; + status = op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + status = op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + // Run + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + } + + // Record an event at the start of a series of GEMMs + result = cudaEventRecord(events[0]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + for (int i = 0; i < options.iterations; i++) { + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMMs are complete + result = cudaEventRecord(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Wait for work on the device to complete. 
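+    // cudaEventSynchronize() blocks the host until the stop event has completed, so the
+    // elapsed time measured below covers all options.iterations launches of op.run().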
+ result = cudaEventSynchronize(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + runtime_ms /= static_cast(options.iterations); + + double flops = 1.0; + flops *= B; + flops *= K; + flops *= H; + flops *= 2.0; + flops *= (2.0 * D_latent + D_rope); + + double bytes_q = sizeof(Element); + bytes_q *= B; + bytes_q *= H; + bytes_q *= (D_latent + D_rope); + double bytes_c = sizeof(Element); + bytes_c *= B; + bytes_c *= options.k; // K may be max_K here + bytes_c *= (D_latent + D_rope); + double bytes_o = sizeof(ElementOut); + bytes_o *= B; + bytes_o *= H; + bytes_o *= D_latent; + double bytes = bytes_q + bytes_c + bytes_o; + + double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); + double tbytes_s = bytes * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); + example_result.tflops_tc_s = tflops_s; + example_result.tbytes_s = tbytes_s; + example_result.runtime_ms = runtime_ms; + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Verify that the result is correct + bool passed = true; + if (options.verify) { + passed = verify(problem_shape); + if (passed) example_result.verified = true; + } + + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + return example_result; + } + + example_result.passed = true; + + return example_result; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to print a description of the example run and its result +void print_result(const std::string& description, ExampleResult result, bool verbose) { + std::ios fmt(nullptr); + fmt.copyfmt(std::cout); + std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] "); + std::cout << std::setw(32) << std::left << description; + std::cout.copyfmt(fmt); + std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl; + if (verbose) { + std::cout << " t=" << result.runtime_ms * 1e3 << " us, " + "smem=" << result.smem_size << "b" << std::endl; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, const char* name, auto... kernel_options) { + if ((! options.kernel_filter.empty()) && (! 
std::regex_search(name, std::basic_regex(options.kernel_filter)))) { + return; + } + Runner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using NumHeads = _128; + using HeadDimLatent = _512; + using HeadDim = Shape; + + std::cout << "###### B " << options.b << " MLA H " << 0 + NumHeads{} << " "; + std::cout << "D_rope " << 0 + get<1>(HeadDim{}) << " D_latent " << 0 + get<0>(HeadDim{}) << " "; + std::cout << "Q 1 K " << options.k << " Gen None "; + std::cout << "Split " << options.split_kv << " Gen None "; + std::cout << "#SM " << hw_info.sm_count << std::endl; + + using Blocking = _128; + std::string name = std::to_string((int) NumHeads{}) + "x" + std::to_string((int) Blocking{}); + std::string individual = " individual"; + std::string persistent = " persistent"; +#if FP8 + name += " fp8"; + // Persistent Tile Scheduler + run(Shape{}, (name + persistent).c_str(), IsPersistent{}); + // Individual Tile Scheduler + run(Shape{}, (name + individual).c_str(), IsPersistent{}); +#elif FP16 + name += " fp16"; + // Persistent Tile Scheduler + run(Shape{}, (name + persistent).c_str(), IsPersistent{}); + // Individual Tile Scheduler + run(Shape{}, (name + individual).c_str(), IsPersistent{}); +#endif +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +int main_single(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) { + std::cout + << "This example requires a GPU of NVIDIA's Blackwell Architecture " + << "(compute capability major 10) and CUDA 12.8 or greater.\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + if (options.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + else { + hw_info.sm_count = options.sm_count; + } + + run_mla(options, hw_info); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + std::vector full_arguments(args, args + argc); + + int result = 0; + + bool recursed = false; + for (size_t i = 1; i < full_arguments.size(); i++) { + if (full_arguments[i].find(',') != std::string::npos) { + auto arg = full_arguments[i]; + size_t eq_pos = arg.find('='); + std::string prefix = eq_pos == std::string::npos ? 
"" : arg.substr(0, eq_pos+1); + std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1); + for (;;) { + size_t comma_pos = rest.find(','); + std::string current = rest.substr(0, comma_pos); + full_arguments[i] = prefix + current; + std::vector next_args; + for (auto& elem : full_arguments) { next_args.push_back(elem.data()); } + main(argc, next_args.data()); + if (comma_pos == std::string::npos) break; + rest = rest.substr(comma_pos+1); + } + recursed = true; + break; + } + } + + if (! recursed) { + main_single(argc, args); + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index 90b4738760..f04ebe417b 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -28,12 +28,14 @@ set_property( - SOURCE 77_blackwell_fmha.cu - PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0") - -set_property( - SOURCE 77_blackwell_fmha_gen.cu - PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0") + SOURCE + 77_blackwell_fmha.cu + 77_blackwell_fmha_gen.cu + 77_blackwell_mla.cu + 77_blackwell_fmha_bwd.cu + PROPERTY + COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0" +) set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no) set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal) @@ -48,58 +50,98 @@ set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify) set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap) set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only) -if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang"))) - if (CUTLASS_NVCC_ARCHS MATCHES 100a) - cutlass_example_add_executable( - 77_blackwell_fmha_fp8 - 77_blackwell_fmha.cu - TEST_COMMAND_OPTIONS - TEST_BASIC - # TEST_CAUSAL - # TEST_VARLEN - # TEST_HDIM64 - # TEST_GQA) - ) - target_include_directories(77_blackwell_fmha_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) - target_compile_definitions(77_blackwell_fmha_fp8 PRIVATE FP8) +set(TEST_MLA_BASIC --b=1 --k=512 --verify) - cutlass_example_add_executable( - 77_blackwell_fmha_gen_fp8 - 77_blackwell_fmha_gen.cu - TEST_COMMAND_OPTIONS - TEST_GEN_BASIC - # TEST_GEN_VARLEN - # TEST_GEN_HDIM64 - # TEST_GEN_GQA - # TEST_GEN_REMAP - # TEST_GEN_CACHEONLY) - ) - target_include_directories(77_blackwell_fmha_gen_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) - target_compile_definitions(77_blackwell_fmha_gen_fp8 PRIVATE FP8) +if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a)) - cutlass_example_add_executable( - 77_blackwell_fmha_fp16 - 77_blackwell_fmha.cu - TEST_COMMAND_OPTIONS - TEST_BASIC - # TEST_CAUSAL - # TEST_VARLEN - # TEST_HDIM64 - # TEST_GQA) - ) - target_include_directories(77_blackwell_fmha_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + foreach(PREC fp8 fp16) + string(TOUPPER "${PREC}" PREC_MACRO) - cutlass_example_add_executable( - 77_blackwell_fmha_gen_fp16 - 77_blackwell_fmha_gen.cu - TEST_COMMAND_OPTIONS - TEST_GEN_BASIC - # TEST_GEN_VARLEN - # TEST_GEN_HDIM64 - # TEST_GEN_GQA - # TEST_GEN_REMAP - # TEST_GEN_CACHEONLY) - ) - target_include_directories(77_blackwell_fmha_gen_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) - endif() + cutlass_example_add_executable( + 77_blackwell_fmha_${PREC} + 77_blackwell_fmha.cu + TEST_COMMAND_OPTIONS + TEST_BASIC + # TEST_CAUSAL + # TEST_VARLEN + # 
TEST_HDIM64 + # TEST_GQA) + ) + target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO}) + + cutlass_example_add_executable( + 77_blackwell_fmha_gen_${PREC} + 77_blackwell_fmha_gen.cu + TEST_COMMAND_OPTIONS + TEST_GEN_BASIC + # TEST_GEN_VARLEN + # TEST_GEN_HDIM64 + # TEST_GEN_GQA + # TEST_GEN_REMAP + # TEST_GEN_CACHEONLY) + ) + target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO}) + + cutlass_example_add_executable( + 77_blackwell_mla_2sm_${PREC} + 77_blackwell_mla.cu + TEST_COMMAND_OPTIONS + TEST_MLA_BASIC + ) + target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO}) + target_compile_options(77_blackwell_mla_2sm_${PREC} PRIVATE -Xptxas -v) + + cutlass_example_add_executable( + 77_blackwell_mla_2sm_cpasync_${PREC} + 77_blackwell_mla.cu + TEST_COMMAND_OPTIONS + TEST_MLA_BASIC + ) + target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC) + target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v) + + cutlass_example_add_executable( + 77_blackwell_mla_b2b_2sm_${PREC} + 77_blackwell_mla.cu + TEST_COMMAND_OPTIONS + TEST_MLA_BASIC + ) + target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B) + target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v) + + cutlass_example_add_executable( + 77_blackwell_fmha_bwd_${PREC} + 77_blackwell_fmha_bwd.cu + TEST_COMMAND_OPTIONS + TEST_BASIC + # TEST_GEN_VARLEN + # TEST_GEN_HDIM64 + # TEST_GEN_GQA + # TEST_GEN_REMAP + # TEST_GEN_CACHEONLY) + ) + target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO}) + target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v) + + cutlass_example_add_executable( + 77_blackwell_fmha_bwd_sat_${PREC} + 77_blackwell_fmha_bwd.cu + TEST_COMMAND_OPTIONS + TEST_BASIC + # TEST_GEN_VARLEN + TEST_GEN_HDIM64 + # TEST_GEN_GQA + # TEST_GEN_REMAP + # TEST_GEN_CACHEONLY) + ) + target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC) + target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v) + endforeach() endif() diff --git a/examples/77_blackwell_fmha/README.md b/examples/77_blackwell_fmha/README.md index 2f4c9c760b..a1536dc8b8 100644 --- a/examples/77_blackwell_fmha/README.md +++ b/examples/77_blackwell_fmha/README.md @@ -22,6 +22,39 @@ The `apply_mask` function is called with the accumulator of the first GEMM and t It is well-suited for applying masks or activations. More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA. +# FMHA for Blackwell: Backward + +This sample provides code for fused multi-head attention backward pass. +It supports HeadDims of 64 and 128, and fp8, fp16, and bf16 input data types. 
+The blocking in sequence length Q and K is 128, and loads are done via TMA.
+We support causal masking.
+The structure of this code is very similar to the forward pass, and the techniques are analogous.
+
+The backward pass is computed by three kernels, launched in this order:
+1. `FmhaKernelBwdSumOdO` computes the row-wise dot product of O and dO.
+2. `Sm100FmhaBwdKernelTmaWarpSpecialized` computes the backward pass proper.
+3. `FmhaKernelBwdConvert` converts dQ from fp32 accumulation to the final output precision.
+
+`Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high-performance fused kernel.
+
+# MLA Inference for Blackwell
+
+This sample provides code for fused multi-head latent attention inference in
+the weight-absorbed regime, i.e. for latent head dim 512 and rope head dim 64.
+It supports fp16, bf16, and fp8 input and output types.
+
+To accommodate the large output accumulator due to the large latent head dimension,
+the sample demonstrates how to leverage 2SM Blackwell tensor cores.
+
+Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`,
+which supports any power-of-two page size less than or equal to 128.
+With paging, the code also supports variable sequence length.
+
+The approach of this implementation is to reuse the selection logic of the collective GEMM builder and recombine the result into an MLA kernel.
+
+The example builds six binaries, showcasing TMA and `cp.async` usage, as well as a back-to-back GEMM (essentially turning the softmax into a no-op) for fp8 and fp16.
+For details on how to invoke them, see the tests in `CMakeLists.txt` or run the binaries with `--help`.
+
 # Copyright
 
 Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
diff --git a/examples/77_blackwell_fmha/common/pow_2.hpp b/examples/77_blackwell_fmha/common/pow_2.hpp
new file mode 100644
index 0000000000..eca93250f4
--- /dev/null
+++ b/examples/77_blackwell_fmha/common/pow_2.hpp
@@ -0,0 +1,92 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace cutlass::fmha { + +struct Pow2 { + int n; + int log2_n; + + explicit CUTE_DEVICE Pow2(int n) : n(n) { +#ifdef __CUDA_ARCH__ + log2_n = __ffs(n) - 1; +#endif + } + + template + CUTE_HOST_DEVICE T operator *(T const& b) const { + return n * b; + } + + template + CUTE_HOST_DEVICE auto operator *(Int const&) const { + if constexpr (N & (N - 1) == 0) { + return Pow2{n * N}; + } + return n * N; + } + +}; + +template +CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) { + return a >> b.log2_n; +} + +template +CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) { + return a & (b.n - 1); +} + +template +CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) { + return a < b.n; +} + +CUTE_HOST_DEVICE void print(Pow2 const& a) { + printf("2^%d", a.log2_n); +} + +} // end namespace cutlass::fmha + +namespace cute { + +template <> +struct is_integral : true_type {}; + +} // end namespace cute diff --git a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp new file mode 100644 index 0000000000..80fcdf9fdf --- /dev/null +++ b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/tensor.hpp" + +#include "../device/fmha.hpp" +#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" +#include "../kernel/fmha_kernel_bwd_convert.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class Sm100FmhaBwd { +public: + /// Argument structure: User API + struct Arguments { + // Q K D HB + cute::tuple> problem_size; + + const Element* ptr_Q; + cute::tuple> stride_Q; + const Element* ptr_K; + cute::tuple> stride_K; + const Element* ptr_V; + cute::tuple> stride_V; + + const Element* ptr_O; + cute::tuple> stride_O; + const ElementAccumulator* ptr_LSE; + cute::tuple> stride_LSE; + + const Element* ptr_dO; + cute::tuple> stride_dO; + + Element* ptr_dQ; + cute::tuple> stride_dQ; + Element* ptr_dK; + cute::tuple> stride_dK; + Element* ptr_dV; + cute::tuple> stride_dV; + + ElementAccumulator softmax_scale; + + cutlass::KernelHardwareInfo hw_info; + }; + + using OperationSumOdO = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdSumOdO + >; + using OperationConvert = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdConvert + >; + + using Operation = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized + >; + using Kernel = typename Operation::Kernel; + + struct Params { + OperationSumOdO op_sum_OdO; + Operation op; + OperationConvert op_convert; + ElementAccumulator* dQ_acc; + size_t dQ_acc_size; + }; + +private: + Params params_; + + static typename OperationSumOdO::Arguments to_sum_OdO_arguments( + Arguments const& args, + ElementAccumulator* sum_odo = nullptr, + ElementAccumulator* scaled_lse = nullptr) { + using namespace cute; + auto [Q, K, D, HB] = args.problem_size; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H)); + auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H)); + auto log2_e = log2f(expf(1.0f)); + return typename OperationSumOdO::Arguments { + args.problem_size, + args.ptr_O, args.stride_O, + args.ptr_dO, args.stride_dO, + sum_odo, stride_sum_OdO, + args.ptr_LSE, args.stride_LSE, + scaled_lse, stride_scaled_lse, + -1.0f, -log2_e + }; + } + + static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { + using namespace cute; + auto [Q, K, D, HB] = args.problem_size; + 
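+    // problem_size is (Q, K, D, (H, B)); D and Q are rounded up to multiples of 8 below,
+    // matching the alignment used when sizing the fp32 dQ accumulator workspace.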
auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); + return typename OperationConvert::Arguments { + args.problem_size, + src, stride_src_dQ, + nullptr, stride_src_dQ, + nullptr, stride_src_dQ, + args.ptr_dQ, args.stride_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV, + args.softmax_scale + }; + } + + static typename Operation::Arguments to_bwd_arguments( + Arguments const& args, + ElementAccumulator* sum_OdO = nullptr, cute::tuple> const& stride_sum_OdO = {}, + ElementAccumulator* scaled_lse = nullptr, cute::tuple> const& stride_scaled_lse = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple> const& stride_dQ = {}) { + return typename Operation::Arguments{ + args.problem_size, + { args.ptr_Q, args.stride_Q, + args.ptr_K, args.stride_K, + args.ptr_V, args.stride_V, + args.ptr_dO, args.stride_dO, + scaled_lse, stride_scaled_lse, + sum_OdO, stride_sum_OdO, + dQ_acc, stride_dQ, + args.softmax_scale }, + { args.ptr_dK, args.stride_dK, + args.ptr_dV, args.stride_dV }, + args.hw_info + }; + } + +public: + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + Status status = Status::kSuccess; + + status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = OperationConvert::can_implement(to_convert_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = Operation::can_implement(to_bwd_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + auto [Q, K, D, HB] = args.problem_size; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + size_t workspace_bytes = 0; + // OdO vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // scaled LSE vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + return workspace_bytes; + } + + /// Initializes state from arguments. + Status + initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" + << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? 
"non-null" : "null")); + + auto [Q, K, D, HB] = args.problem_size; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); + params_.dQ_acc = dQ_acc; + params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); + auto args_convert = to_convert_arguments(args, dQ_acc); + params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); + params_.op_convert.initialize(args_convert, nullptr, stream); + auto args_bwd = to_bwd_arguments( + args, sum_OdO, args_sum_OdO.stride_sum_OdO, + scaled_lse, args_sum_OdO.stride_scaled_lse, + dQ_acc, args_convert.stride_src_dQ + ); + params_.op.initialize(args_bwd, nullptr, stream); + + return Status::kSuccess; + } + + /// Initializes state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + auto [Q, K, D, HB] = args.problem_size; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + char* workspace_chr = reinterpret_cast(workspace); + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); + return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); + + Status result = Status::kSuccess; + result = params.op_sum_OdO.run(stream); + if (result != Status::kSuccess) { + return result; + } + + auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); + if (cuda_result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = params.op.run(stream); + if (result != Status::kSuccess) { + return result; + } + + result = params.op_convert.run(stream); + if (result != Status::kSuccess) { + return result; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. 
+ Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/device/sm100_mla.hpp b/examples/77_blackwell_fmha/device/sm100_mla.hpp new file mode 100644 index 0000000000..4e09809007 --- /dev/null +++ b/examples/77_blackwell_fmha/device/sm100_mla.hpp @@ -0,0 +1,357 @@ +/*************************************************************************************************** + * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. 
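+
+         For MLA, this device layer pairs the SM100 FMHA MLA kernel with a reduction
+         kernel that merges the per-split partial results whenever split_kv > 1.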
+*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp" +#include "kernel/sm100_fmha_mla_reduction.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +using namespace cute; +using namespace cutlass::fmha::kernel; + + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class Kernel_ +> +class MLA { +public: + + using Kernel = Kernel_; + + using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel< + typename Kernel::ElementOut, + typename Kernel::ElementAcc, + typename Kernel::ElementAcc, + Kernel::TileShapeH::value, + Kernel::TileShapeL::value, + 256 /*Max split*/ + >; + + /// Argument structure: User API + using KernelArguments = typename Kernel::Arguments; + using ReductionArguments = typename ReductionKernel::Arguments; + + using Arguments = KernelArguments; + + /// Argument structure: Kernel API + using KernelParams = typename Kernel::Params; + using ReductionParams = typename ReductionKernel::Params; + struct Params { + KernelParams fmha_params; + ReductionParams reduction_params; + }; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + + static ReductionArguments to_reduction_args(Arguments const& args) { + auto [H, K, D, B] = args.problem_shape; + return ReductionArguments{ + nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse, + args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq, + args.ptr_split_kv, Kernel::TileShapeS::value + }; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + static void set_split_kv (KernelArguments& args) { + if (args.split_kv >= 1) return; + auto [H, K, D, B] = args.problem_shape; + int sm_count = args.hw_info.sm_count; + int max_splits = ceil_div(K, 128); + int sms_per_batch = max(1, sm_count / B); + int split_heur = min(max_splits, sms_per_batch); + int waves = ceil_div(B * split_heur, sm_count); + int k_waves = ceil_div(max_splits, split_heur); + int split_wave_aware = ceil_div(max_splits, k_waves); + args.split_kv = split_wave_aware; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (! Kernel::can_implement(args)) { + return Status::kInvalid; + } + if (! 
ReductionKernel::can_implement(to_reduction_args(args))) { + return Status::kInvalid; + } + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args)); + return workspace_bytes; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream); + if (status != Status::kSuccess) { + return status; + } + KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {kernel_params, reduction_params}; + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + // no dynamic smem is needed for reduction kernel + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + auto fmha_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {fmha_params, reduction_params}; + + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+  /// Primary run() entry point API that is static allowing users to create and manage their own params.
+  /// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
+  static Status
+  run(Params& params, cudaStream_t stream = nullptr) {
+    CUTLASS_TRACE_HOST("MLA::run()");
+    dim3 const block = Kernel::get_block_shape();
+    dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
+
+    // configure smem size and carveout
+    int smem_size = Kernel::SharedStorageSize;
+
+    Status launch_result;
+    // Use extended launch API only for mainloops that use it
+    if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
+      dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
+                   cute::size<1>(typename Kernel::ClusterShape{}),
+                   cute::size<2>(typename Kernel::ClusterShape{}));
+      void const* kernel = (void const*) device_kernel<Kernel>;
+      void* kernel_params[] = {&params.fmha_params};
+      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
+    }
+    else {
+      launch_result = Status::kSuccess;
+      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
+    }
+
+    cudaError_t result = cudaGetLastError();
+    if (cudaSuccess != result or Status::kSuccess != launch_result) {
+      //return Status::kSuccess;
+      CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
+      return Status::kErrorInternal;
+    }
+    if (params.reduction_params.split_kv > 1) {
+      // launch reduction kernel
+      dim3 const block = ReductionKernel::get_block_shape();
+      dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
+      device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
+      cudaError_t result = cudaGetLastError();
+      if (cudaSuccess == result) {
+        return Status::kSuccess;
+      }
+      else {
+        CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
+        return Status::kErrorInternal;
+      }
+    }
+    else {
+      return Status::kSuccess;
+    }
+  }
+
+  //
+  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
+  //
+
+  /// Launches the kernel after first constructing Params internal state from supplied arguments.
+  Status
+  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
+    Status status = initialize(args, workspace, stream);
+    if (Status::kSuccess == status) {
+      status = run(params_, stream);
+    }
+    return status;
+  }
+
+  /// Launches the kernel after first constructing Params internal state from supplied arguments.
+  Status
+  operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
+    return run(args, workspace, stream);
+  }
+
+  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
+  Status
+  run(cudaStream_t stream = nullptr) {
+    return run(params_, stream);
+  }
+
+  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
+  Status
+  operator()(cudaStream_t stream = nullptr) {
+    return run(params_, stream);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::fmha::device
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp
new file mode 100644
index 0000000000..c2618bcb70
--- /dev/null
+++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp
@@ -0,0 +1,146 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES.
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdConvert { + + struct Arguments { + tuple> problem_size; + + const ElementAcc* ptr_src_dQ; + tuple> stride_src_dQ; + const ElementAcc* ptr_src_dK; + tuple> stride_src_dK; + const ElementAcc* ptr_src_dV; + tuple> stride_src_dV; + + Element* ptr_dest_dQ; + tuple> stride_dest_dQ; + Element* ptr_dest_dK; + tuple> stride_dest_dK; + Element* ptr_dest_dV; + tuple> stride_dest_dV; + + ElementAcc scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static const int kBlockSeq = 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kNumThreadsD = 16; + static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 4; + + static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_size) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsSeq, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + 
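  // Illustrative scalar reference for this kernel (not from this file): it rescales and downcasts
  // the fp32 dQ/dK/dV accumulators to Element. The grid covers the two (head, batch)-like modes
  // plus ceil_div(max(seqlen_Q, seqlen_K), kBlockSeq) sequence blocks, the block is
  // kNumThreadsD x kNumThreadsSeq = 16 x 8 threads, and each thread moves kElementsPerLoad = 4
  // contiguous values per step (a 16B load when ElementAcc is float).
  #if 0
  // per (batch, head) slice; `seqlen`, `head_dim`, and the strides are placeholder names
  for (int s = 0; s < seqlen; ++s) {
    for (int d = 0; d < head_dim; ++d) {
      dest[s * stride_dest_seq + d] = static_cast<Element>(scale * src[s * stride_src_seq + d]);
    }
  }
  #endif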
+ template + CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) { + auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; + + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { + int idx_s = idx_s_t + kBlockSeq * blockIdx.z; + if (idx_s >= count) continue; + auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); + auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + ElementAcc value_src[kElementsPerLoad]; + Element value_dest[kElementsPerLoad]; + + using VecSrc = uint_bit_t * kElementsPerLoad>; + using VecDest = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_src) = *reinterpret_cast(&ptr_src_bhs[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + value_dest[v] = static_cast(params.scale * value_src[v]); + } + + *reinterpret_cast(&ptr_dest_bhs[idx_d]) = *reinterpret_cast(value_dest); + } + } + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + if (params.ptr_src_dQ != nullptr) { + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size)); + } + if (params.ptr_src_dK != nullptr) { + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size)); + } + if (params.ptr_src_dV != nullptr) { + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size)); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp new file mode 100644 index 0000000000..44080e2d10 --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -0,0 +1,151 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdSumOdO { + + struct Arguments { + cute::tuple> problem_size; + + const Element* ptr_O; + cute::tuple> stride_O; + const Element* ptr_dO; + cute::tuple> stride_dO; + + ElementAcc* ptr_sum_OdO; + cute::tuple> stride_sum_OdO; + + const ElementAcc* ptr_lse = nullptr; + cute::tuple> stride_lse; + + ElementAcc* ptr_scaled_lse = nullptr; + cute::tuple> stride_scaled_lse; + + ElementAcc sum_odo_scale = 1.0; + ElementAcc lse_scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kBlockQ = 16; + + static const int kNumThreadsD = 8; + static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 2; + + static const int kIterationsQ = kBlockQ / kNumThreadsQ; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_size) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsQ, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); + auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); + auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); + + CUTLASS_PRAGMA_UNROLL + for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { + int idx_q = idx_q_t + kBlockQ * blockIdx.x; + if (idx_q >= get<0>(params.problem_size)) continue; + ElementAcc acc = 0; + auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O); + auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO); + auto 
ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO); + auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse); + auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + Element value_O[kElementsPerLoad]; + Element value_dO[kElementsPerLoad]; + + using Vec = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); + *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + acc += value_O[v] * value_dO[v]; + } + } + + for (int i = 1; i < kNumThreadsD; i *= 2) { + acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); + } + + if (threadIdx.x == 0) { + *ptr_sum_OdO_bhq = params.sum_odo_scale * acc; + if (params.ptr_scaled_lse) { + *ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq; + } + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000..e1bd43d5e5 --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1699 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeK = decltype(get<1>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<2>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{}); + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kS + TileShapeQ{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeKQ = typename CollectiveMmaKQ::TileShape; + using TiledMmaKQ = typename 
CollectiveMmaKQ::TiledMma; + + // compute dP + using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeVDO = typename CollectiveMmaVDO::TileShape; + using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{})); + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + 
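  // Reading of the producer/consumer graph these pipelines encode (annotation, not from this
  // file), based on the pipeline names and their use in load()/mma()/compute() below:
  //   Load    -> MMA     : Q tiles (load_mma_q), dO tiles (load_mma_do)                         [TMA]
  //   Load    -> Compute : LSE rows (load_compute_lse), sum(O*dO) rows (load_compute_sum_odo)   [cp.async]
  //   MMA     -> Compute : S = Q*K (mma_compute_s), dP = dO*V (mma_compute_dp), dK/dV ready (mma_compute_dkdv)
  //   Compute -> MMA     : P (compute_mma_p, via TMEM), dS (compute_mma_ds, via SMEM)
  //   MMA     -> Reduce  : dQ partials (mma_reduce_dq), written out with TMA reduce-add (reduce_tma_store)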
alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using ProblemShape = Shape>; // Q K D (H B), eventuall D = (D_QK, D_VO) + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + 
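  // Annotation (not from this file): dQ is not written in Element directly. Every K-tile CTA
  // produces a partial dQ for the same query rows, so partials are accumulated into this fp32
  // buffer via TMA reduce-add (see TMA_DQ / SM90_TMA_REDUCE_ADD below); a separate pass,
  // presumably FmhaKernelBwdConvert above, then applies the final scaling and downcasts to
  // Element. The softmax_scale default of 1 / sqrt(TileShapeDQK) is the usual attention scale.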
ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaKQ::Params::TMA_A; + using TMA_V = typename CollectiveMmaVDO::Params::TMA_A; + using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B; + using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0) { + return false; + } + if (D % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q, K, D, HB] = args.problem_shape; + + auto params_kq = CollectiveMmaKQ::to_underlying_arguments( + make_shape(K, Q, D, HB), + typename CollectiveMmaKQ::Arguments { + args.mainloop.ptr_k, args.mainloop.stride_k, + args.mainloop.ptr_q, args.mainloop.stride_q, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaVDO::to_underlying_arguments( + make_shape(K, Q, D, HB), + typename CollectiveMmaVDO::Arguments { + args.mainloop.ptr_v, args.mainloop.stride_v, + args.mainloop.ptr_do, args.mainloop.stride_do, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_a, + params_vdo.tma_load_a, + params_kq.tma_load_b, + params_vdo.tma_load_b, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename 
PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mQ = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mV = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB)); + auto mDO = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB)); + + auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step{}); + auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_A(gK); + auto tSTgQ = cta_mma_kq.partition_B(gQ); + auto tDPTgV = cta_mma_vdo.partition_A(gV); + auto tDPTgDO = cta_mma_vdo.partition_B(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading 128 values of 32b each + // so 4*32b=128b + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + 
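    // Worked example of the indexing above (annotation, not from this file): one warp (32 threads)
    // covers the TileShapeQ = 128 LSE rows of this Q tile. Thread t issues one 16B cp.async for
    // rows [4*t, 4*t + 4), i.e. gmem_idx = 128 * iter_index + 4 * t, and the _zfill variant below
    // writes zeros instead of loading when the `gmem_idx < Q` predicate fails at the tail.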
cutlass::arch::cp_async_zfill<16>( + shared_tensors.smem_lse.begin() + smem_idx, + &mLSE(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + cutlass::arch::cp_async<16>( + shared_tensors.smem_sum_odo.begin() + smem_idx, + &mSumOdO(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + cutlass::arch::cp_async<16>( + shared_tensors.smem_lse.begin() + smem_idx, + &mLSE(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + 
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + cutlass::arch::cp_async_zfill<16>( + shared_tensors.smem_sum_odo.begin() + smem_idx, + &mSumOdO(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK); + Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ); + + Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV); + Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + tDVrP.data() = TmemAllocation::kP; + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaKQ tiled_mma_kq; + TiledMmaVDO tiled_mma_vdo; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = 
partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + 
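      // Annotation (not from this file): why the dP acquire below must precede the dQ GEMM. With
      // TileShapeQ = TileShapeK = 128 and (assuming) 128-wide head dimensions, TmemAllocation puts
      // dK at column 0, dV at 128, dQ/dP at 256, and S/P at 384, totalling 512 columns, i.e. the
      // full SM100 TMEM capacity. Because dQ aliases dP (and the next S aliases P), the MMA warp
      // has to re-own those buffers via producer_acquire before issuing GEMMs that overwrite them.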
pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + 
++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + auto tCg = thr_copy.partition_D(gmem); + auto tCr = thr_copy.partition_S(quantize(regs)); + auto tCc = thr_copy.partition_D(coord); + + constexpr int R = decltype(tCr.layout())::rank; + auto tCg_v = group_modes<1, R>(tCg); + auto tCr_v = group_modes<1, R>(tCr); + auto tCc_v = group_modes<1, R>(tCc); + auto tCp_v = make_tensor(shape<1>(tCc_v)); + + for (int i = 0; i < size(tCp_v); ++i) { + tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); + } + + copy_if(copy_op, tCp_v, tCr_v, tCg_v); + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = 
make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. 
wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + auto store_op = SM100_TMEM_STORE_32dp32b8x{}; + + Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST)); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + + + auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST); + tDVrP.data() = TmemAllocation::kP; + + auto tiled_r2t = make_tmem_copy(store_op, tDVrP); + auto thread_r2t = tiled_r2t.get_slice(dp_idx); + + auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP)); + auto tRT_cST = split_wg(thread_r2t.partition_S(tDVcST)); + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + dispatch_bool(std::is_base_of_v && + warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (std::is_base_of_v && decltype(is_causal_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for 
(int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<1>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); + }); + + // notify for P + cutlass::arch::fence_view_async_tmem_store(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<1>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_128{}, _128{}), + make_stride(_1{}, _0{}) + ); + + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; 
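+        // Per-tile math recap for the loop body above (standard FMHA-backward identities;
+        // the sLSE / sSumOdO conventions are assumed from the preprocess step):
+        //   P_ij  = exp(softmax_scale * S_ij - LSE_i), evaluated as exp2f(fma(softmax_scale * log2e, S_ij, lse_i)),
+        //           where the staged lse_i is assumed to already fold in -LSE_i * log2(e).
+        //   dS_ij = P_ij * (dP_ij - sum_d dO_id * O_id); sum_OdO arrives pre-negated, hence the cute::add.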
+ } + + epilogue( + blk_coord, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_32dp32b32x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + 
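+              // tma_red_dq performs a reduce-add into gDQ, since every K-tile CTA contributes a partial dQ
+              // for the same Q rows; the commit below lets this smem stage be reacquired once the store drains.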
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + 
pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + 
pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto 
pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, make_coord(blockIdx.y, blockIdx.z)); + auto problem_shape = params.problem_shape; + int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } + iter_count -= iter_start; + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, HB] = 
params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp new file mode 100644 index 0000000000..c6a0575013 --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/arch.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +template< + class ElementOut, + class ElementAcc, + class ElementScale, + size_t kNumHeads, + size_t kHeadDimLatent, + int kMaxSplits +> +struct Sm100FmhaMlaReductionKernel { + + static const int SharedStorageSize = 0; + static const int MaxThreadsPerBlock = 128; + static const int MinBlocksPerMultiprocessor = 1; + + using ArchTag = cutlass::arch::Sm100; + + static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0); + struct Arguments { + ElementAcc* ptr_oaccum = nullptr; + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_lseaccum = nullptr; + ElementAcc* ptr_lse = nullptr; + ElementScale scale = 1.f; + int num_batches = 0; + int split_kv = -1; + int dim_k = -1; + int* ptr_seq = nullptr; + int* ptr_split_kv = nullptr; + int tile_shape_s = 128; + }; + using Params = Arguments; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse, + args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq, + args.ptr_split_kv, args.tile_shape_s}; + } + + static size_t get_workspace_size(Arguments const& /*args*/) { + return 0; + } + + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return dim3(kNumHeads, 1, params.num_batches); + } + + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + static bool can_implement(Arguments const& args) { + if (args.num_batches <= 0) return false; + if (args.split_kv <= 0) return false; + return true; + } + + CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) { + if (params.split_kv <= 1) return; + auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z); + + __shared__ ElementAcc sLseScale[kMaxSplits]; + const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord); + const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord); + + Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum), + make_shape(params.split_kv), Stride>{}); + + Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse), + Shape<_1>{}, Stride<_1>{}); + + auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)]; + auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)]; + auto k_tile_total = ceil_div(dim_k, params.tile_shape_s); + auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv); + local_split_kv = ceil_div(k_tile_total, k_tile_per_cta); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + ElementAcc local_lse[kNLsePerThread]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + local_lse[i] = split < local_split_kv ? 
gLSEaccum(split) : -std::numeric_limits::infinity(); + } + + ElementAcc lse_max = -std::numeric_limits::infinity(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + lse_max = max(lse_max, local_lse[i]); + } + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset)); + } + lse_max = lse_max == -std::numeric_limits::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf + lse_max = __shfl_sync(0xffffffff, lse_max, 0); + + ElementAcc sum_lse = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max); + } + + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset); + } + + sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); + + ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : logf(sum_lse) + params.scale * lse_max; + if (threadIdx.x == 0 and params.ptr_lse != nullptr) { + gLSE(0) = global_lse; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + if (split < local_split_kv) { + sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + } + __syncthreads(); + + constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock; + const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord)); + Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum), + Shape>{}, Stride<_1>{}); + ElementAcc local_val[Elements] = {0}; + for (int split = 0; split < local_split_kv; ++split) { + ElementAcc lse_scale = sLseScale[split]; + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i); + } + gOaccum.data() = gOaccum.data() + kHeadDimLatent; + } + auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent; + Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape>{}, Stride<_1>{}); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast(local_val[i]); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp new file mode 100644 index 0000000000..acb89a9def --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -0,0 +1,2018 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "gather_tensor.hpp" // from examples/common +#include "common/pow_2.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template< + class TileShape, + class Element_, + class ElementAcc_, + class ElementOut_, + class ElementLSE_, + class TileScheduler, +#ifdef CPASYNC + bool kIsCpAsync = true +#else + bool kIsCpAsync = false +#endif +> +struct Sm100FmhaMlaKernelTmaWarpspecialized { + + using Element = Element_; + using ElementAcc = ElementAcc_; + using ElementOut = ElementOut_; + using ElementLSE = ElementLSE_; + + // only 2Sm mode is supported + static const bool kIs2Sm = true; + static const int MaxThreadsPerBlock = 256; + static const int MinBlocksPerMultiprocessor = 1; + static const int TotalSNum = 2; + static const int TotalPNum = 2; + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = cute::conditional_t, Shape<_1, _1, _1>>; + + using TileShapeH = tuple_element_t<0, TileShape>; + using TileShapeS = tuple_element_t<1, TileShape>; + using TileShapeD = tuple_element_t<2, TileShape>; + + using TileShapeL = tuple_element_t<0, TileShapeD>; + using TileShapeR = tuple_element_t<1, TileShapeD>; + static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim"); + + using ProblemShape = Shape; + using TensorStride = Stride; + using TmemAllocator = cute::conditional_t; + + static_assert(TileShapeH{} == 128); + static const int kWarpsInN = kIs2Sm ? 2 : 1; + + static const int kNumComputeWarps = 4; + static const int kNumLoadWarps = kIsCpAsync ? 2 : 1; + + enum class WarpRole { + kMma = 0x1, kLoad = 0x2, kCompute = 0x3, kLoadPageTable = 0x4, kEmpty=0x0 + }; + + static const long long unsigned int kWarpAssignment = kIsCpAsync ? 
0x4221'3333ull : 0x0021'3333ull; + + static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + static const int Alignment = 128 / sizeof_bits_v; + static const int AlignmentOut = 128 / sizeof_bits_v; + + using TileShapeQK = Shape; + static const int StagesQK = 24 / sizeof(Element); // free parameter + static const int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQK = IterationsQKLatent + IterationsQKRope; + + using Schedule = cute::conditional_t; + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TensorStride, Alignment, + ElementAcc, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK; + + // chosen for unified smem staging between K and V + using TileShapePV = Shape; + using TransposeTensorStride = decltype(select<1,0,2>(TensorStride{})); + static const int StagesPV = StagesQK; // not sure why, but must be at least two. check pipes + static const int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value; + static const int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TransposeTensorStride, Alignment, + ElementAcc, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK; + static_assert(std::is_same_v); + + using TiledMmaPV = typename CollectiveMmaPV::TiledMma; + + using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK; + static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == typename CollectiveMmaPV::AtomThrShapeMNK{}, "schedule must match"); + + static const int StagesPageTable = kIsCpAsync ? 
StagesPV : 1; + + // pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd + // use expect_tx for Q load + using PipelineLoadQK = cute::conditional_t, PipelineTmaUmmaAsync>; + using PipelineLoadPV = PipelineLoadQK; + // pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages + using PipelineS = PipelineUmmaAsync; + // pipeline from softmax (P) to mma (bmm2), PipelineUmmaAsync, 2 stages + using PipelineP = PipelineUmmaConsumerAsync; + // pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage + using PipelineO = PipelineUmmaAsync<1, AtomThrShapeMNK>; + + using PipelinePT = PipelineAsync; + + struct PipelineStorage { + alignas(16) typename PipelineLoadQK::SharedStorage load_qk; + alignas(16) typename PipelineS::SharedStorage mma_s; + alignas(16) typename PipelineP::SharedStorage p_mma; + alignas(16) typename PipelineO::SharedStorage mma_o; + alignas(16) typename PipelinePT::SharedStorage load_page_table; + }; + + template + static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB; + using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int{}, _2{}))); + + static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v); + static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v); + // pre-condition for overlapped smem staging + static_assert(kBytesLoadKC == kBytesLoadVC); + static_assert(StagesQK == StagesPV); + + static const int kTransactionsBytesLoadQK = kBytesLoadKC; + static const int kTransactionsBytesLoadExtraQ = kBytesLoadQ; + static const int kTransactionsBytesLoadPV = kBytesLoadVC; + + static const int kNamedBarrierExchange = (int) cutlass::arch::ReservedNamedBarriers::TransformBarrier; + // This Named Barrier is introduced to solve Q tile loading overwritten issue when enable persistent + // tile scheduler for FP8 MLA. 
+ static const int kNamedBarrierEpilogue = (int) cutlass::arch::ReservedNamedBarriers::EpilogueBarrier; + // + static const int kNamedBarrierTmemDealloc = (int) cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier; + + enum class TmemAllocation : uint32_t { + kSizeS = TileShapeS::value / kWarpsInN, + // Overall + kSizeO = TileShapeL::value / kWarpsInN, + // Between accumulators we loop over + kSizeAccO = decltype(get<1>(TileShapePV{}))::value / kWarpsInN, + kNumS = TotalSNum, + kNumP = TotalPNum, + kNumO = 1, + kS0 = 0, + kS1 = kS0 + kSizeS, + kO0 = kS1 + kSizeS, + kTotal = kO0 + kSizeO + }; + + static_assert(static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem"); + + struct TensorStorage { + // to communicate max and row_sum + cute::array smem_exchange; + cute::array smem_page_table; + alignas(2048) cute::array> smem_q; + union { + alignas(2048) cute::array> smem_kc; + alignas(2048) cute::array> smem_vc; + }; + alignas(2048) cute::array> smem_p; + }; + + struct SharedStorage { + PipelineStorage pipelines; + TensorStorage tensors; + uint32_t tmem_base_ptr; + }; + + static const int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + struct MainloopArguments { + ElementAcc softmax_scale; + + // all tensors strides are (num_heads or seqlen, head_dim, batch) + // head_dim stride is always 1 + Element* ptr_q_latent; + TensorStride stride_q_latent; + Element* ptr_q_rope; + TensorStride stride_q_rope; + + Element* ptr_c_latent; + TensorStride stride_c_latent; + Element* ptr_k_rope; + TensorStride stride_k_rope; + + // for paged attention, we interpret what was previously [batch, seqlen] + // as [page_count, page_size], and index according to page_table + int* ptr_seq = nullptr; + int* ptr_page_table = nullptr; + // page table is [batch, seqlen or similar] + Stride<_1, int> stride_page_table = {}; + int page_count = 0; + int page_size = TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS + }; + + struct EpilogueArguments { + ElementOut* ptr_o = nullptr; + TensorStride stride_o; + ElementLSE* ptr_lse = nullptr; + Stride<_1, int> stride_lse; + ElementAcc output_scale = 1.0f; + }; + + struct Arguments { + // (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count) + // for paged attention, seqlen is max seqlen + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B; + + struct MainloopParams { + TmaLoadQLatent tma_load_q_latent; + TmaLoadQRope tma_load_q_rope; + TmaLoadCLatent tma_load_c_latent; + TmaLoadKRope tma_load_k_rope; + TmaLoadCLatentTranspose tma_load_c_latent_transpose; + }; + + struct EpilogueParams { + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_o_acc = nullptr; + TensorStride stride_o; + TensorStride stride_o_acc; + ElementLSE* ptr_lse = nullptr; + ElementLSE* ptr_lse_acc = nullptr; + Stride<_1, int> stride_lse; + Stride<_1, int> stride_lse_acc; + ElementAcc output_scale = 1.0f; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + 
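+    // Unlike EpilogueArguments, EpilogueParams also carries the split-KV accumulator views
+    // (ptr_o_acc / ptr_lse_acc and their strides) that to_underlying_arguments points into the
+    // workspace when split_kv > 1; they are consumed by the separate MLA reduction kernel.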
EpilogueParams epilogue; + MainloopParams mainloop_params; + typename TileScheduler::Params tile_scheduler; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + //workspace = nullptr; // let's get an error if one of these needs workspace + + auto [H, K, D, B] = args.problem_shape; + auto [L, R] = D; + + int paged_B = B; + int paged_K = K; + if (args.mainloop.ptr_page_table != nullptr) { + paged_B = args.mainloop.page_count; + paged_K = args.mainloop.page_size; + } + + auto params_qk_latent = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, L, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_latent_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, L, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_rope = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, R, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + auto params_qk_rope_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, R, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + + auto stride_c_latent_transpose = select<1,0,2>(args.mainloop.stride_c_latent); + auto params_pv_latent = CollectiveMmaPV::to_underlying_arguments( + make_shape(H, L, paged_K, paged_B), + typename CollectiveMmaPV::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, // dummy, never used + args.mainloop.ptr_c_latent, stride_c_latent_transpose, + }, nullptr); + + MainloopParams mainloop_params { + params_qk_latent.tma_load_a, + params_qk_rope.tma_load_a, + params_qk_latent_paged.tma_load_b, + params_qk_rope_paged.tma_load_b, + params_pv_latent.tma_load_b + }; + + EpilogueParams epilogue_params; + + epilogue_params.ptr_o = args.epilogue.ptr_o; + epilogue_params.stride_o = args.epilogue.stride_o; + epilogue_params.ptr_lse = args.epilogue.ptr_lse; + epilogue_params.stride_lse = args.epilogue.stride_lse; + epilogue_params.output_scale = args.epilogue.output_scale; + + if (args.split_kv > 1) { + ElementAcc* ptr_o_acc = reinterpret_cast(workspace); + ElementLSE* ptr_lse_acc = reinterpret_cast(ptr_o_acc + H * L * args.split_kv * B); + epilogue_params.ptr_o_acc = ptr_o_acc; + epilogue_params.ptr_lse_acc = ptr_lse_acc; + + epilogue_params.stride_o_acc = make_tuple(static_cast(0 + L) * args.split_kv, _1{}, static_cast(0 + H * L) * args.split_kv); + epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv); + } + + return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params, + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), args.split_kv, args.ptr_split_kv}; + } + + static size_t get_workspace_size(Arguments const& args) { + ProblemShape problem_shape = args.problem_shape; + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + auto split_kv = args.split_kv; + return (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B; + 
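+    // i.e. the workspace is laid out (as assumed by to_underlying_arguments above) as
+    //   [ ElementAcc o_accum : H * D_latent * split_kv * B ][ ElementLSE lse_accum : H * split_kv * B ]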
} + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static bool can_implement(Arguments const& args) { + if (kIsCpAsync) { + if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) { + return false; + } + if (args.mainloop.page_size > TileShapeS{}) { + return false; + } + } + else { + if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) { + return false; + } + } + if (get<0>(args.problem_shape) != 128) { + return false; + } + if (get<1>(args.problem_shape) <= 0) { + return false; + } + if (args.split_kv <= 0) { + return false; + } + return true; + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) { + + TileScheduler tile_scheduler(params.tile_scheduler); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + int cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{}); + bool is_mma_leader_cta = cta_coord_v == 0; + + if (role == WarpRole::kLoad && lane_predicate && ! kIsCpAsync) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent_transpose.get_tma_descriptor()); + } + SharedStorage& shared_storage = *reinterpret_cast(smem_raw); + + typename PipelineLoadQK::Params pipeline_load_qk_params; + if (role == WarpRole::kLoad) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Producer; + } + if (role == WarpRole::kMma) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Consumer; + } + if constexpr (kIsCpAsync) { + // we can make our life easier by unconditionally loading blocks + // since we know it'll always be legal + pipeline_load_qk_params.producer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + } + else { + pipeline_load_qk_params.is_leader = lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta; + pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK; + } + pipeline_load_qk_params.initializing_warp = 0; + PipelineLoadQK pipeline_load_qk(shared_storage.pipelines.load_qk, pipeline_load_qk_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineS::Params pipeline_mma_s_params; + if (role == WarpRole::kMma) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_s_params.initializing_warp = 1; + PipelineS pipeline_mma_s( + shared_storage.pipelines.mma_s, + pipeline_mma_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename 
PipelineP::Params pipeline_p_mma_params; + if (role == WarpRole::kMma) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Consumer; + } + if (role == WarpRole::kCompute) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer; + } + pipeline_p_mma_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_p_mma_params.consumer_arv_count = 1; + pipeline_p_mma_params.initializing_warp = 2; + PipelineP pipeline_p_mma( + shared_storage.pipelines.p_mma, + pipeline_p_mma_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineO::Params pipeline_mma_o_params; + if (role == WarpRole::kMma) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_o_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_o_params.initializing_warp = 3; + PipelineO pipeline_mma_o( + shared_storage.pipelines.mma_o, + pipeline_mma_o_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelinePT::Params pipeline_pt_params; + if (role == WarpRole::kLoad) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Consumer; + } + if (role == WarpRole::kLoadPageTable) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer; + } + pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp; + pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp; + pipeline_pt_params.initializing_warp = 4; + PipelinePT pipeline_page_table( + shared_storage.pipelines.load_page_table, + pipeline_pt_params); + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_qk.init_masks(ClusterShape{}); // do we need an update here for 2Sm? 
+ pipeline_mma_s.init_masks(ClusterShape{}); + pipeline_p_mma.init_masks(ClusterShape{}); + pipeline_mma_o.init_masks(ClusterShape{}); + + typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state; + typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = cutlass::make_producer_start_state(); + + typename PipelineS::PipelineState pipeline_mma_s_consumer_state; + typename PipelineS::PipelineState pipeline_mma_s_producer_state = cutlass::make_producer_start_state(); + + typename PipelineP::PipelineState pipeline_p_mma_consumer_state; + typename PipelineP::PipelineState pipeline_p_mma_producer_state = cutlass::make_producer_start_state(); + + typename PipelineO::PipelineState pipeline_mma_o_consumer_state; + typename PipelineO::PipelineState pipeline_mma_o_producer_state = cutlass::make_producer_start_state(); + + typename PipelinePT::PipelineState pipeline_pt_consumer_state; + typename PipelinePT::PipelineState pipeline_pt_producer_state = cutlass::make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + if (role == WarpRole::kLoadPageTable) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_page_table( + blk_coord, + problem_shape, + params.mainloop, + shared_storage.tensors, + pipeline_page_table, pipeline_pt_producer_state, + local_split_kv + ); + } + } + else if (role == WarpRole::kLoad) { + if constexpr (kIsCpAsync) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_cpasync( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv, + /* must be shared pipe */ + pipeline_page_table, pipeline_pt_consumer_state + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + if (params.mainloop.ptr_page_table != nullptr) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); 
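+          // Per the kNamedBarrierEpilogue note above: the load warp waits for the compute warps here so that
+          // the next tile's Q load cannot overwrite smem still in use when the persistent scheduler advances.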
+ cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + } + } + else if (role == WarpRole::kMma) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + if (is_mma_leader_cta) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + mma(blk_coord, + problem_shape, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_mma_s, pipeline_mma_s_producer_state, + pipeline_p_mma, pipeline_p_mma_consumer_state, + pipeline_mma_o, pipeline_mma_o_producer_state, + local_split_kv + ); + } + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive_and_wait(); + + //uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + //tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + else if (role == WarpRole::kCompute) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + compute( + blk_coord, + problem_shape, + params.mainloop, // for softmax_scale + params.epilogue, + shared_storage.tensors, // for smem_comm + pipeline_mma_s, pipeline_mma_s_consumer_state, + pipeline_p_mma, pipeline_p_mma_producer_state, + pipeline_mma_o, pipeline_mma_o_consumer_state, + local_split_kv + ); + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + } + + cute::cluster_sync(); + cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + if (role == WarpRole::kMma) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + 
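+      // Only the MMA warp (which allocated TMEM) frees it; the cluster_sync above ensures the compute
+      // warps have finished reading their accumulators before the capacity is released.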
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + + template + CUTLASS_DEVICE void load_page_table( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_producer_state, int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + int batch_coord = get<2>(blk_coord); + + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(mainloop_args.page_count, B), + mainloop_args.stride_page_table); + auto mPT = mPT_l(_, batch_coord); + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + auto page_size = Pow2{mainloop_args.page_size}; + auto pages_per_tile = Pow2{TileShapeS{} / page_size}; + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp; + +#if 1 + for (; k_tile_count > 0; ++k_index, --k_tile_count) { + pipeline_page_table.producer_acquire(pipeline_pt_producer_state); + + // assume a single warp + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) { + int idx = i + thread_idx; + bool guard = idx < pages_per_tile; + int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx; + int pt_idx = pages_per_tile * k_index + idx; + + cutlass::arch::cp_async_zfill( + &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard + ); + } + + pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_pt_producer_state; + } +#endif + } + + + struct Gather { + int& page_table_stage; + Pow2 pages_per_tile; + const int * __restrict__ smem_page_table; + + CUTLASS_DEVICE int operator()(int idx) const { + return smem_page_table[page_table_stage * TileShapeS::value + idx % pages_per_tile]; + } + + CUTLASS_DEVICE friend void print(Gather const&) { + printf(""); + } + + }; + + + template + CUTLASS_DEVICE void load_cpasync( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load, + typename PipelineLoadQK::PipelineState& pipeline_load_producer_state, + int const& split_kv, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_consumer_state) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using X = Underscore; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // partition all tensors + auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), mainloop_args.stride_q_latent); + auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), mainloop_args.stride_q_rope); + + int paged_B = mainloop_args.page_count; + auto paged_K = Pow2{mainloop_args.page_size}; + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + int 
batch_coord = get<2>(blk_coord); + auto mPT = mPT_l(_, batch_coord); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto make_copy_for = [](auto sT) { + auto rT_a = sT.layout()(_, _, _, _0{}); + auto rT = make_ordered_layout(shape(rT_a), stride(rT_a)); + auto threads = Int{}; + auto values = Int{}; + return make_cotiled_copy( + Copy_Atom, Element>{}, + make_ordered_layout( + make_shape(threads, values), + make_stride(_1{}, _0{})), + rT); + }; + + // like cute::copy, but makes sure we do all page table lookups first + auto copy_split = [](auto atom, auto src, auto dst) { + auto src_v = group_modes<1, rank_v>(src); + auto dst_v = group_modes<1, rank_v>(dst); + + auto src_v_ptrs = make_tensor(size<1>(src_v)); + for (int i = 0; i < size<1>(src_v); i++) { + src_v_ptrs(i) = &src_v(_0{}, i); + } + + + for (int i = 0; i < size<1>(src_v); i++) { + auto src_v_i = make_tensor( + make_gmem_ptr(src_v_ptrs(i)), + make_shape(shape<0>(src_v)), + make_stride(make_stride(_1{}, _0{})) + ); + atom.call(src_v_i, dst_v(_, i)); + } + }; + + auto tiled_copy_q = make_copy_for(sQ); + auto tiled_copy_kc = make_copy_for(sKC); + auto tiled_copy_vc = make_copy_for(sVC); + + auto thr_copy_q = tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_kc = tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_vc = tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQL = thr_copy_q.partition_S(tSgQL); + auto tQgQR = thr_copy_q.partition_S(tSgQR); + + auto tKCsKC = thr_copy_kc.partition_D(sKC); + auto tVCsVC = thr_copy_vc.partition_D(sVC); + + auto pipeline_pt_release_state = pipeline_pt_consumer_state; + + int page_table_stage = -1; + Pow2 pages_per_tile{TileShapeS{} / paged_K}; + const int * __restrict__ smem_page_table = shared_tensors.smem_page_table.begin(); + Gather gather{page_table_stage, pages_per_tile, smem_page_table}; + + auto mCL = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))), get<1>(mainloop_args.stride_c_latent))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mKR = make_tensor( + make_gmem_ptr(mainloop_args.ptr_k_rope), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_k_rope), example::CustomStride(gather, get<2>(mainloop_args.stride_k_rope))), get<1>(mainloop_args.stride_k_rope))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mCLT = 
make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(_1{}, make_shape(paged_K, paged_B)), + make_stride(get<1>(mainloop_args.stride_c_latent), make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(D_latent, paged_K * paged_B))}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + auto tKCgCL = thr_copy_kc.partition_S(tSgCL); + auto tKCgKR = thr_copy_kc.partition_S(tSgKR); + auto tVCgCLT = thr_copy_vc.partition_S(tOgCLT); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + auto& pipeline_acquire_state = pipeline_load_producer_state; + auto pipeline_commit_state = pipeline_acquire_state; + int pipeline_offset = 0; + + for (int i = 0; i < StagesPV; i++) { + cutlass::arch::cp_async_fence(); + } + + auto load_stage = [&](auto fn) { + pipeline_load.producer_acquire(pipeline_acquire_state); + fn(pipeline_acquire_state.index()); + cutlass::arch::cp_async_fence(); + + ++pipeline_acquire_state; + ++pipeline_offset; + + if (pipeline_offset == StagesPV - 1) { + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + }; + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i)); + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, IterationsQKLatent + i)); + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + k_index += 1; + 
k_tile_count -= 1; + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + while (pipeline_offset > 0) { + cutlass::arch::cp_async_fence(); + + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + + cutlass::arch::cp_async_wait<0>(); + + } + + + template + CUTLASS_DEVICE void load_tma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + using X = Underscore; + + // partition all tensors + auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B)); + auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B)); + + int paged_B = B; + int paged_K = K; + if constexpr (kIsPaged) { + paged_B = mainloop_args.page_count; + paged_K = mainloop_args.page_size; + } + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B)); + auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B)); + + auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B)); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto [tQLgQL_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}), + 
group_modes<0,3>(sQ), group_modes<0,3>(tSgQL)); + + auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition( + mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQR)); + + auto [tCLgCL_nkl, tKCsKC] = tma_partition( + mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgCL)); + + auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition( + mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgKR)); + + auto [tCLTgCLT_nkl, tVCsVC] = tma_partition( + mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}), + group_modes<0,3>(sVC), group_modes<0,3>(tOgCLT)); + + uint16_t mcast_mask = 0; + + int batch_coord = get<2>(blk_coord); + Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord); + Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord); + + auto mPT = mPT_l(_, batch_coord); + + Tensor tCLgCL = tCLgCL_nkl(_, _, _, _); + Tensor tKRgKR = tKRgKR_nkl(_, _, _, _); + + // careful: stage and k are swapped here! + Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), tQLgQL(_, _0{}, i), tQsQ(_, i)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + // perform K load + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier 
= pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + // prefetch next K load to keep busy while we transpose-load from cache + const int kPrefetchDistance = 1; + for (int i = 0; i < IterationsQKLatent; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + for (int i = 0; i < IterationsQKRope; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + // perform V load (k_idx - 1) + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices! 
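+            // The PV operand re-reads the latent KV cache through the transposed TMA
+            // descriptor (tma_load_c_latent_transpose), which is why the tile and
+            // column indices appear swapped relative to the K-side load above.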
+ // note we are off-by-one on k_index + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + + k_index += 1; + k_tile_count -= 1; + } + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices + // note we are off-by-one on k_index + + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + } + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_producer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_consumer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // mma init + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}); + + Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ); + Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC); + Tensor tOrP = TiledMmaPV::make_fragment_A(sP); + Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC); + + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero; + + 
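+    // MMA-warp schedule: the first S = Q*K^T tile is issued below, outside the
+    // steady-state loop, so that each loop iteration overlaps "S for tile k" with
+    // "O += P*V for tile k-1" (see the ownership diagram). The S accumulator
+    // ping-pongs between the two TMEM slots kS0/kS1 selected by the S-pipeline
+    // stage index.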
pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + + // Mma S0 S1 O0 S2 O1 ... Sn On-1 On + // S0 ownership -- ----- -- -- + // S1 ownership -- ----- ---- + // O ownership -- -- ---- -- + + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + + --k_tile_count; + } + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * 
uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + } + + + template + CUTLASS_DEVICE void softmax( + IsLastTile const& is_last_tile, + ElementAcc& row_max, + ElementAcc& row_sum, + ElementAcc& correction_factor, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + int k_index, + uint32_t tmem_s, + int smem_p_index) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaQK tiled_mma_qk; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tStS.data() = tmem_s; + + CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{}); + Tensor tAcc = tStS(make_coord(_,_),_0{},_0{}); + + Tensor cS = make_identity_tensor(take<0,2>(CtaShapeQK{})); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_cS = thread_t2r.partition_D(cS); + Tensor tTR_rAcc = make_tensor(shape(tTR_cS)); + + Tensor tTR_rS_frag = make_tensor(shape(tTR_rAcc)); + const int AlignmentS = 4; + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + Tensor tTR_rAcc_vec = recast>(tTR_rAcc); + Tensor tTR_rS_vec = recast>(tTR_rS_frag); + + // load s + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + if (is_last_tile) { + for (int i = 0; i < size(tTR_rAcc); i++) { + if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) { + tTR_rAcc(i) = -std::numeric_limits::infinity(); + } + } + } + + // max + ElementAcc row_max_new = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 1) { + row_max_new = ::fmax(row_max_new, tTR_rAcc(i)); + } + + // for 2x2 dp, reduce here + if constexpr (kWarpsInN > 1) { + shared_tensors.smem_exchange[threadIdx.x] = row_max_new; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]); + } + +#ifndef B2B + // find correction factor + ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast(M_LOG2E); + correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new)); + row_max = row_max_new; + + // softmax + ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2); + } +#endif + + // quantize + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc_vec); i++) { + tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i)); + } + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index)); + + Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS); 
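+    // The quantized P fragment is staged to shared memory next, in the layout the PV
+    // TiledMma expects for its A operand. After that the online-softmax running sum
+    // is updated (non-B2B path); in scalar form (illustrative):
+    //   row_sum = row_sum * correction_factor
+    //           + sum_j exp2f(softmax_scale_log2 * S_j - softmax_scale_log2 * row_max)
+    // with the exponentials already applied to tTR_rAcc above.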
+ + // have a mapping for each thread to coord + // find identical mapping to coords for the MMA + auto l = make_ordered_layout(make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{}))); + auto sP_ = as_position_independent_swizzle_tensor(sP); + copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _)); + + // sum + row_sum *= correction_factor; + + static_assert(cute::is_same_v); + auto tTR_rAcc_float2 = recast(tTR_rAcc); + auto sums = make_tensor(_4{}); + static_assert(size(tTR_rAcc_float2) % size(sums) == 0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(sums); i++) { + sums(i) = tTR_rAcc_float2(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = size(sums); i < size(tTR_rAcc_float2); i += size(sums)) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j++) { + cute::add(sums(j), sums(j), tTR_rAcc_float2(i + j)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < size(sums); i *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j += 2*i) { + cute::add(sums(j), sums(j), sums(j+i)); + } + } + row_sum += sums(0).x + sums(0).y; + } + + + CUTLASS_DEVICE void rescale( + ElementAcc correction_factor, + uint32_t tmem_o) { + + // for b2b gemm, do nothing +#ifndef B2B + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + auto store_op = TMEM::tmem_load_to_store(load_op); + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*) nullptr), cta_tiler_pv, make_stride(0, 0)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto tiled_r2t = make_tmem_copy(store_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + auto thread_r2t = tiled_r2t.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + // load o + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + // multiply by correction factor + float2 correction_factor_vec = make_float2(correction_factor, correction_factor); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 2) { + float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1)); + float2 out; + cute::mul(out, in, correction_factor_vec); + tTR_rAcc(i + 0) = out.x; + tTR_rAcc(i + 1) = out.y; + } + + // store o + copy(tiled_r2t, tTR_rAcc, tTR_tAcc); +#endif + } + + + template + CUTLASS_DEVICE void epilogue( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + uint32_t tmem_o, + int const& split_kv) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + if (epilogue_args.ptr_o_acc != nullptr) { + using 
ElementOutAcc = ElementAcc; + constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + // for 2x2 dp, this must be conditional and the index is wrong + if (! kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + #endif + } + else { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + if (epilogue_args.ptr_lse != nullptr) { + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + // for 2x2 dp, this must be conditional and the index is wrong + if (! 
kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + } +#endif + } + } + + + template + CUTLASS_DEVICE void compute( + CtaCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_consumer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_producer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_consumer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(cta_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + + // if we return early, we have to make sure we release the load warp + cutlass::arch::NamedBarrier( + (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue + ).arrive(); + + return; + } + int k_index_final = k_tile_total - 1; + + ElementAcc row_max = -std::numeric_limits::infinity(); + ElementAcc row_sum = 0; + ElementAcc correction_factor = 1; + + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + // softmax s0 -> p0 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + k_index += 1; + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + // softmax s1 -> p1 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? 
TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + + // rescale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO)); + } + + cutlass::arch::fence_view_async_tmem_store(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + + --k_tile_count; + k_index += 1; + } + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + +#ifdef B2B + row_sum = 1; +#else + if constexpr (kWarpsInN > 1) { + // reduce row_sum if needed (for 2x2 dp) + shared_tensors.smem_exchange[threadIdx.x] = row_sum; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_sum += shared_tensors.smem_exchange[peer_index]; + } +#endif + + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive(); + + // epilogue + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + epilogue( + row_max, row_sum, + replace<1>(cta_coord, j), problem_shape, + mainloop_args, epilogue_args, shared_tensors, + uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv + ); + } + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp b/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp new file mode 100644 index 0000000000..dbcc2ce8b8 --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp @@ -0,0 +1,160 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaIndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z); + } + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_split_kv; + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = size<0>(cluster_shape); + int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */; + num_blocks *= split_kv; /* Maximum Split KV*/ + + return Params { + num_blocks, + { num_m_blocks}, { get<3>(problem_shape) }, {split_kv}, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + 
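+    // Decode the flat persistent block index into its coordinates by peeling off the
+    // m-block, then the batch index, then the split-KV index with the precomputed
+    // FastDivmods (fastest- to slowest-varying).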
using namespace cute; + int block_decode = block_idx; + int m_block, bidb, n_split_kv; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_split_kv(block_decode, n_split_kv, block_decode); + return make_coord(m_block, _0{}, bidb, n_split_kv); + } + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel + diff --git a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp new file mode 100644 index 0000000000..bb8cfb348b --- /dev/null +++ b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, /* class TensorDK, class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dQ_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */ + Fusion fusion) { + + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) { + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + ElementAccumulator acc_qk = 0; + ElementAccumulator acc_dov = 0; + ElementAccumulator acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } // for idx_D0 + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_K] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } // for idx_K + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + acc += mS[idx_K] * mK(idx_K, idx_D, idx_L); + } + mDQ(idx_Q, idx_D, idx_L) = static_cast(acc); + } // for idx_D + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, */ class TensorDK, /* class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dK_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */ + Fusion fusion) { + + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) { + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + ElementAccumulator acc_qk = 0; + ElementAccumulator acc_dov = 0; + ElementAccumulator acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, 
idx_D0, idx_L); + acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } // for idx_D0 + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } // for idx_Q + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { + acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L); + } + mDK(idx_K, idx_D, idx_L) = static_cast(acc); + } // for idx_D + } // for idx_K + } // for idx_L +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, class TensorDK, */ class TensorDV, + class Fusion +> +void __global__ fmha_bwd_reference_dV_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV, + Fusion fusion) { + + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAcc softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) { + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + ElementAcc acc_qk = 0; + + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L); + ElementAcc rK = mK(idx_K, idx_D0, idx_L); + acc_qk += rQ * rK; + } // for idx_D0 + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L))); + } // for idx_Q + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) { + ElementAcc acc = 0; + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { + ElementAcc rS = mS[idx_Q]; + ElementAcc rDO = mDO(idx_Q, idx_D, idx_L); + acc += rS * rDO; + } + mDV(idx_K, idx_D, idx_L) = static_cast(acc); + } // for idx_D + } // for idx_K + } // for idx_L +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dQ( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/ + Fusion fusion) { + + using namespace cute; + + dim3 grid(size<0>(mDQ), size<2>(mDQ), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * 
sizeof(typename TensorO::value_type); + fmha_bwd_reference_dQ_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dK( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/ + Fusion fusion) { + + using namespace cute; + + dim3 grid(size<0>(mDK), size<2>(mDK), 1); + dim3 block(256); + int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + fmha_bwd_reference_dK_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/ + class Fusion +> +void fmha_bwd_reference_dV( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/ + Fusion fusion) { + + using namespace cute; + + dim3 grid(size<0>(mDV), size<2>(mDV), 1); + dim3 block(256); + int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + fmha_bwd_reference_dV_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, class TensorDK, class TensorDV, + class Fusion +> +void fmha_bwd_reference( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, TensorDK mDK, TensorDV mDV, + Fusion fusion) { + + fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); + fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); + fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp index 48d8110187..b7c6b412cb 100644 --- a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp @@ -128,7 +128,7 @@ void __global__ fmha_reference_kernel( } if (threadIdx.x == 0) { - mLSE(idx_Q + offset_Q, idx_L) = log(sum) + maxS; + mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS; } } diff --git a/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp new file mode 100644 index 0000000000..29db90746e --- /dev/null +++ b/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp @@ -0,0 +1,206 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorSeq, + class TensorPageTable, + class TensorQL, + class TensorQR, + class TensorCL, + class TensorKR, + class TensorO, + class TensorLSE, + class Scale +> +void __global__ fmha_mla_reference_kernel( + ProblemShape problem_shape, + TensorSeq mSeq, TensorPageTable mPT, + TensorQL mQL, TensorQR mQR, + TensorCL mCL, TensorKR mKR, + TensorO mO, TensorLSE mLSE, + Scale softmax_scale) { + + using namespace cute; + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ ElementAcc mS[]; + // ElementAcc* mS = reinterpret_cast(mS_mem); + + for (int idx_B = blockIdx.y; idx_B < B; idx_B += gridDim.y) { + if (mSeq.data() != nullptr) { + K = mSeq(idx_B); + } + + for (int idx_H = blockIdx.x; idx_H < H; idx_H += gridDim.x) { + + for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) { + ElementAcc acc = 0; + + for (int idx_D = 0; idx_D < D_latent; idx_D++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eQ = mQL(idx_H, idx_D, idx_B); + ElementAcc eK = mCL(page_idx_K, idx_D, page_idx_B); + acc += eQ * eK; + } + + for (int idx_D = 0; idx_D < D_rope; idx_D++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eQ = mQR(idx_H, idx_D, idx_B); + ElementAcc eK = mKR(page_idx_K, idx_D, page_idx_B); + acc += eQ * eK; + } + mS[idx_K] = acc; 
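+        // mS[idx_K] now holds the raw Q·K logit for this key position (latent plus RoPE
+        // contributions); the softmax scale and max-subtraction are applied only after the
+        // full scan over keys below.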
+ } + + __syncthreads(); + + ElementAcc maxS = -std::numeric_limits::infinity(); + for (int idx_K = 0; idx_K < K; idx_K++) { + maxS = std::max(maxS, mS[idx_K]); + } + if (maxS == -std::numeric_limits::infinity()) maxS = 0; + + __syncthreads(); + +#ifndef B2B + for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) { + mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS)); + } +#endif + + __syncthreads(); + + ElementAcc sum = 0; + for (int idx_K = 0; idx_K < K; idx_K++) { + sum += mS[idx_K]; + } + + ElementAcc o_scale = 1.0f / sum; +#ifdef B2B + o_scale = 1.0; +#endif + + for (int idx_D = threadIdx.x; idx_D < D_latent; idx_D += blockDim.x) { + ElementAcc acc = 0; + for (int idx_K = 0; idx_K < K; idx_K++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eV = mCL(page_idx_K, idx_D, page_idx_B); + ElementAcc eK = static_cast(mS[idx_K]); + acc += eK * eV; + } + mO(idx_H, idx_D, idx_B) = static_cast(acc * o_scale); + } + + if (threadIdx.x == 0) { + mLSE(idx_H, idx_B) = log(sum) + softmax_scale * maxS; + } + + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorSeq, + class TensorPageTable, + class TensorQL, + class TensorQR, + class TensorCL, + class TensorKR, + class TensorO, + class TensorLSE, + class Scale +> +void fmha_mla_reference( + ProblemShape problem_shape, + TensorSeq mSeq, TensorPageTable mPT, + TensorQL mQL, TensorQR mQR, + TensorCL mCL, TensorKR mKR, + TensorO mO, TensorLSE mLSE, + Scale scale) { + + using namespace cute; + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + dim3 grid(H, B, 1); + dim3 block(256); + int shared_mem = K * int(sizeof(typename TensorLSE::value_type)) + 16; + cudaError_t result; + if (shared_mem >= (48 << 10)) { + result = cudaFuncSetAttribute( + &fmha_mla_reference_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + throw std::runtime_error("couldn't perform smem optin"); + } + } + fmha_mla_reference_kernel<<>>( + problem_shape, mSeq, mPT, mQL, mQR, mCL, mKR, mO, mLSE, scale); + cudaDeviceSynchronize(); + result = cudaGetLastError(); + if (cudaSuccess != result) { + throw std::runtime_error("couldn't execute reference"); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/reference/reference_abs_error.hpp b/examples/77_blackwell_fmha/reference/reference_abs_error.hpp index e4a01c8216..6d833ad12a 100644 --- a/examples/77_blackwell_fmha/reference/reference_abs_error.hpp +++ b/examples/77_blackwell_fmha/reference/reference_abs_error.hpp @@ -178,3 +178,96 @@ void reference_abs_diff( max_diff = result_host[0]; mean_diff = result_host[1] / static_cast(data.size()); } + +template +__global__ void reference_rel_diff_kernel( + Element* data, Element* data_ref, size_t count, + double* max_diff, double* sum_diff, + bool print_diff ) { + + double thread_max_diff = 0; + double thread_sum_diff = 0; + + __shared__ double block_max_diff; + __shared__ double block_sum_diff; + + for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) { + double diff = fabs(data[i] - data_ref[i]) / fabs(data_ref[i]); + if (print_diff) if (diff != diff || diff > 0.01f) 
printf("difference at %lld: %f ... %f vs %f\n", static_cast(i), diff, (double)data[i], (double)data_ref[i]); + thread_max_diff = fmax(diff, thread_max_diff); + thread_sum_diff += diff; + } + + for (int i = 0; i < blockDim.x; i++) { + if (i == threadIdx.x) { + if (i == 0) { + block_max_diff = thread_max_diff; + block_sum_diff = thread_sum_diff; + } + else { + block_max_diff = fmax(block_max_diff, thread_max_diff); + block_sum_diff += thread_sum_diff; + } + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + atomicAdd(sum_diff, block_sum_diff); + + for (;;) { + unsigned long long prev = *reinterpret_cast(max_diff); + double prev_diff = reinterpret_cast(prev); + double new_max_diff = fmax(block_max_diff, prev_diff); + unsigned long long found = atomicCAS(reinterpret_cast(max_diff), prev, reinterpret_cast(new_max_diff)); + if (found == prev) break; + } + } +} + +template +void reference_rel_diff( + DeviceAllocation const& data, + DeviceAllocation const& data_ref, + double& max_diff, double& mean_diff) { + + static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1; + + DeviceAllocation result; + result.reset(2); + assert(data.size() == data_ref.size()); + + cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double)); + if (err != cudaSuccess) { + std::cerr << "Memset failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + dim3 block(256, 1, 1); + dim3 grid(1024, 1, 1); + reference_rel_diff_kernel<<>>( + data.get(), data_ref.get(), data.size(), + result.get(), result.get() + 1, kPrintDiff); + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "Difference kernel failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + double result_host[2]; + err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault); + if (err != cudaSuccess) { + std::cerr << "Copy failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + max_diff = result_host[0]; + mean_diff = result_host[1] / static_cast(data.size()); +} diff --git a/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu new file mode 100644 index 0000000000..d36bf4dd74 --- /dev/null +++ b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu @@ -0,0 +1,927 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+
+/*! \file
+    \brief Grouped GEMM example using CUTLASS 3x APIs for the NVIDIA Blackwell SM120 architecture.
+
+    This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM120 TensorOp-based warp-specialized kernel
+    for narrow precisions (FP4) with input Scale Factors.
+    For this example, all scheduling work is performed on the device, using device-side modification of TMA descriptors
+    to move between groups/problem_count (represented by groups).
+    https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device
+
+    To run this example:
+
+    $ ./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10
+
+    The command above sizes all 10 groups at the given m, n, and k extents.
+    Omitting any of the problem dimensions randomizes that dimension across the different groups,
+    and the same applies to alpha and beta, which are randomized per group when not specified.
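+
+    For example, fixing only n and k randomizes m (and, when not given, alpha and beta) independently for each group:
+
+    $ ./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm --n=4096 --k=4096 --groups=8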
+ + To run this example for a set of problems using the benchmark option: + + $ ./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "helper.h" +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 32; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // Element type for B matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = float_e2m1_t; // Element type for D matrix operands +using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand + +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Alignment of D matrix in units of elements (up to 16 bytes) +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for internal computation +using ArchTag = cutlass::arch::Sm120; // Tag 
indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Epilogue Operator class tag + +// Kernel Perf config +// Cluster Shape fixed to 1x1x1 +using ThreadBlockShape = Shape<_128,_128,_128>; +using ClusterShape = Shape<_1,_1,_1>; +constexpr int OutputSFVectorSize = 16; + +// D = alpha * acc + beta * C +// With BlockScaleFactor generation. +using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + OutputSFVectorSize, + ElementD, + ElementCompute, + ElementSFD, LayoutCTag, + ElementC>; + +// Cooperative kernel schedule +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag *, AlignmentC, + ElementD, LayoutCTag *, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation +>::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag *, AlignmentA, + ElementB, LayoutBTag *, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Auto schedule defaults to cooperative schedule +>::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + +// Pingpong kernel schedule +using CollectiveMainloopPingpong = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag *, AlignmentA, + ElementB, LayoutBTag *, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong +>::CollectiveOp; + +using GemmKernelPingpong = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopPingpong, + CollectiveEpilogue +>; + +using GemmPingpong = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; +using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + OutputSFVectorSize, + cute::is_same_v ? 
cute::UMMA::Major::K : cute::UMMA::Major::MN + >; +using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; +using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF; + +// Host-side allocations +std::vector stride_A_host; +std::vector stride_B_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; +std::vector stride_C_host; +std::vector stride_D_host; + +std::vector alpha_host; +std::vector beta_host; + +using HostTensorA = cutlass::HostTensor; +using HostTensorB = cutlass::HostTensor; +using HostTensorSF = cutlass::HostTensor; +using HostTensorC = cutlass::HostTensor; +using HostTensorD = cutlass::HostTensor; +std::vector block_A; +std::vector block_B; +std::vector block_SFA; +std::vector block_SFB; +std::vector block_C; +std::vector block_D; +std::vector block_SFD; +std::vector block_ref_D; +std::vector block_ref_SFD; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_SFA; +cutlass::DeviceAllocation ptr_SFB; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_SFD; +cutlass::DeviceAllocation ptr_ref_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; +// A matrix wide constant value to scale the output matrix +// Avoids generating small FP4 values. 
+// NormConst is a single device-side constant value, its not per-batch or per-group +cutlass::DeviceAllocation norm_constant_device; + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams::RasterOrderOptions; +// Command line options parsing +struct Options { + + bool help = false; + bool verification = true; + bool use_pdl = false; + + float alpha = std::numeric_limits::max(); + float beta = std::numeric_limits::max(); + float norm_constant = 1.0; + int iterations = 10; + int m = 1024, n = 2048, k = 512, groups = 10; + RasterOrderOptions raster_order = RasterOrderOptions::AlongN; + int max_sm_count = INT_MAX; + std::string benchmark_path; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + if (cmd.check_cmd_line_flag("no_verif")) { + verification = false; + } + if (cmd.check_cmd_line_flag("use_pdl")) { + use_pdl = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, std::numeric_limits::max()); + cmd.get_cmd_line_argument("beta", beta, std::numeric_limits::max()); + cmd.get_cmd_line_argument("norm_constant", norm_constant, float(1.0)); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster_order = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster_order = RasterOrderOptions::AlongM; + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_host.reserve(groups); + + for (int i = groups; i > 0; i--) { + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + if (m < 1) { + m = alignment * ((rand() % 64) + 1); + } + if (n < 1) { + n = alignment * ((rand() % 64) + 1); + } + if (k < 1) { + k = alignment * ((rand() % 64) + 1); + } + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || 
extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + } + groups = static_cast(problem_sizes_host.size()); + + return true; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "79d_blackwell_geforce_nvfp4_grouped_gemm\n\n" + << " Blackwell Block Scaled Narrow Precision Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --norm_constant= Epilogue scalar normalization constant for the output matrix\n\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M)\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --benchmark= Executes a benchmark problem size\n" + << " --max_sm_count= Run kernels using only these number of SMs\n" + << " --no_verif Do not run (host-side) verification kernels\n" + << " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "79d_blackwell_geforce_nvfp4_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + 
view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_A = make_layout(make_shape(M, K, 1), stride_A); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + + stride_A_host.push_back(stride_A); + stride_B_host.push_back(stride_B); + layout_SFA_host.push_back(layout_SFA); + layout_SFB_host.push_back(layout_SFB); + stride_C_host.push_back(stride_C); + stride_D_host.push_back(stride_D); + + block_A.push_back(HostTensorA(cutlass::make_Coord(size(layout_A)))); + block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B)))); + block_SFA.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFA))))); + block_SFB.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFB))))); + block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C)))); + block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + block_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD))))); + block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + block_ref_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD))))); + } + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + uint64_t seed = 2020; + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_SFA_host(options.groups); + std::vector ptr_SFB_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_SFD_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + + initialize_block(block_A.at(i).host_view(), seed + 2021); + initialize_block(block_B.at(i).host_view(), seed + 2022); + initialize_block(block_C.at(i).host_view(), seed + 2023); + initialize_block(block_SFA.at(i).host_view(), seed + 2024); + initialize_block(block_SFB.at(i).host_view(), seed + 2025); + + block_A.at(i).sync_device(); + block_B.at(i).sync_device(); + block_C.at(i).sync_device(); + block_SFA.at(i).sync_device(); + block_SFB.at(i).sync_device(); + + ptr_A_host.at(i) = block_A.at(i).device_data(); + ptr_B_host.at(i) = block_B.at(i).device_data(); + ptr_SFA_host.at(i) = block_SFA.at(i).device_data(); + 
ptr_SFB_host.at(i) = block_SFB.at(i).device_data(); + ptr_C_host.at(i) = block_C.at(i).device_data(); + ptr_D_host.at(i) = block_D.at(i).device_data(); + ptr_SFD_host.at(i) = block_SFD.at(i).device_data(); + + alpha_host.push_back((options.alpha == std::numeric_limits::max()) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == std::numeric_limits::max()) ? static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_SFA.reset(options.groups); + ptr_SFA.copy_from_host(ptr_SFA_host.data()); + + ptr_SFB.reset(options.groups); + ptr_SFB.copy_from_host(ptr_SFB_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_SFD.reset(options.groups); + ptr_SFD.copy_from_host(ptr_SFD_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + + norm_constant_device.reset(1); + norm_constant_device.copy_from_host(&options.norm_constant); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(Options &options, bool host_problem_shapes_available = true) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count); + + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + // If alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. 
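+  // In both cases, dAlpha/dBeta express how the scalar varies across (M, N, group):
+  // a group stride of 0 broadcasts the single command-line value to every group, while a
+  // group stride of 1 steps through alpha_ptr_array / beta_ptr_array so that each group
+  // reads its own per-group scaling value.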
+ if (options.alpha != std::numeric_limits::max()){ + // Single alpha for all groups + fusion_args.alpha = options.alpha; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.dAlpha = {_0{}, _0{}, 0}; + } + else { + fusion_args.alpha = 0; + fusion_args.alpha_ptr_array = alpha_device.get(); + // Only one alpha per each group + fusion_args.dAlpha = {_0{}, _0{}, 1}; + } + if (options.beta != std::numeric_limits::max()) { + // Single beta for all groups + fusion_args.beta = options.beta; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dBeta = {_0{}, _0{}, 0}; + } + else { + fusion_args.beta = 0; + fusion_args.beta_ptr_array = beta_device.get(); + // Only one beta per each group + fusion_args.dBeta = {_0{}, _0{}, 1}; + } + + // Output Block SF + fusion_args.block_scale_factor_ptr = ptr_SFD.get(); // Enable for SF Output + fusion_args.norm_constant_ptr = norm_constant_device.get(); // Enable for SF Output + + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = options.raster_order; + + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; + } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; + } + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + bool passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + auto layout_A = make_layout(make_shape(M, K, 1), stride_A); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.at(i).host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.at(i).host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.at(i).host_data(), layout_SFB); + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = 
cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C); + auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D); + auto tensor_ref_SFD = cute::make_tensor(make_iterator(block_ref_SFD.at(i).host_data()), layout_SFD); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementCompute, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_ref_D), // TensorD + decltype(tensor_ref_SFD), // TensorSfD + cute::Int, + cutlass::reference::host::SfStrategy::SfDGen + > epilogue_params {alpha_host.at(i), beta_host.at(i), tensor_C, tensor_ref_D, tensor_ref_SFD, options.norm_constant}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + block_D.at(i).sync_host(); + block_SFD.at(i).sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view()); + passed &= cutlass::reference::host::TensorEquals(block_ref_SFD.at(i).host_view(), block_SFD.at(i).host_view()); + // Check that the tensors have non-zero norms + passed &= (cutlass::reference::host::TensorNorm(block_ref_D.at(i).host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.at(i).host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_ref_SFD.at(i).host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_SFD.at(i).host_view()) > 0); + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + if (options.verification) { + std::cout << " Host-side verification is now running - may be very slow for large cases." << std::endl; + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + if (!result.passed) { + exit(-1); + } + } + else { + std::cout << " Verfication is turned off for this run." 
<< std::endl; + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + } + timer.stop(); + + // Compute average setup and runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host); + + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS : " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || + ((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8) + ) + ) { + std::cerr << "This example requires CUDA 12.8 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (!(props.major == 12 && props.minor == 0)) { + std::cerr + << "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 120a).\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + allocate(options); + initialize(options); + + // + // Evaluate CUTLASS kernels + // + + std::cout << "Running kernel with Cooperative kernel schedule:" << std::endl; + run(options, false /*host_problem_shapes_available*/); + std::cout << "Running kernel with Pingpong kernel schedule:" << std::endl; + run(options, false /*host_problem_shapes_available*/); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/79_blackwell_geforce_gemm/CMakeLists.txt b/examples/79_blackwell_geforce_gemm/CMakeLists.txt index cb7e3e97c0..b689c85e7e 100644 --- a/examples/79_blackwell_geforce_gemm/CMakeLists.txt +++ b/examples/79_blackwell_geforce_gemm/CMakeLists.txt @@ -28,6 +28,24 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
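+# Each TEST_* variable below holds one set of command-line options for the new example;
+# they are passed to cutlass_example_add_executable() via TEST_COMMAND_OPTIONS further down,
+# so each entry is exercised as a separate test invocation of 79d_blackwell_geforce_nvfp4_grouped_gemm.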
+set(TEST_RANDOM --iterations=0) # Random problem sizes +set(TEST_RANDOM_LARGE_GROUP --groups=50 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=50 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes +set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes + +set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes +set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=51 --iterations=0) # Fixed problem sizes + +set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes +set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0) # Small problem sizes + +set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes +set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes + if (CUTLASS_NVCC_ARCHS MATCHES 120a) cutlass_example_add_executable( 79a_blackwell_geforce_nvfp4_bf16_gemm @@ -44,4 +62,22 @@ cutlass_example_add_executable( 79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu ) +cutlass_example_add_executable( + 79d_blackwell_geforce_nvfp4_grouped_gemm + 79d_blackwell_geforce_nvfp4_grouped_gemm.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + TEST_RANDOM_PERF + TEST_RANDOM_PERF_LARGE_GROUP +) + endif() diff --git a/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu b/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu new file mode 100644 index 0000000000..32df1146ae --- /dev/null +++ b/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu @@ -0,0 +1,554 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture. + + This example demonstrates a simple way to instantiate and run a narrow precision blockscaled sparse GEMM on the NVIDIA Blackwell SM120 architecture. + This kernel is optimized for the GeForce RTX 50 series GPUs. + + The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Sparse Tensor Core MMA Instructions: + * mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale. + Please see more detail in https://docs.nvidia.com/cuda/parallel-thread-execution. + + The kernel leverages: + 1. Warp-Specialized persistent kernel design that supports cooperative scheduler introduced in Hopper. + 2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + 3. Block Scaled Sparse Tensor Core MMA Instructions + + Note that GeForce RTX 50 series GPUs do not support: + 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. + 2. Dynamic datatypes. + + Usage: + $ ./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm --m=2048 --n=2048 --k=2048 +*/ +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" + +#include "helper.h" +using namespace cute; +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::mx_float8_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int 
AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::mx_float8_t; // Element type for B matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 16; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C/D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// E matrix configuration. Note, E is used to represent metadata tensor. +using ElementE = uint8_t; // Element type for E matrix operand +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag +using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120; // Kernel schedule policy +// Kernel Perf config +using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelScheduleType // Mainloop schedule policy + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); 
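+// Unlike the A/B/C/D layouts above, the metadata (E) layout is not built from a packed stride;
+// it is taken directly from the collective mainloop, which owns the sparse metadata format.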
+using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +// +// Data members +// +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +LayoutE layout_E; +uint64_t seed; +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_A_Decompressed; +cutlass::HostTensor block_E; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; +cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// +// Command line options parsing +struct Options { + bool help; + float alpha, beta; + int iterations; + int m, n, k; + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + out << "80a_blackwell_geforce_mxfp8_bf16_sparse_gemm\n\n" + << " Blackwell MXFP8 Sparse GEMM is a warp specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + out << "\n\nExamples:\n\n" + << "$ " << "./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + return out; + } + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} +}; +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} +/// Initialize blocks that released to sparse Matrix A and its metadata E +bool initialize_sparse_blocks(const Options &options) { + auto workload = make_shape(options.m, + options.n, + options.k, + 1); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + /// Alias SparseConfig and Compressor + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA::DataType, + LayoutATag, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA::DataType, + LayoutATag, + SparseConfig, + cutlass::arch::Sm120>; + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + /// Declare compressor_utility to randomly fill zero in Matrix A to match sparsity needs + CompressorUtility compressor_utility(workload, stride_A); + // Aligned M K dimension size for A and E + int aligned_m_e = compressor_utility.get_metadata_m_physical(); + int aligned_k_e = compressor_utility.get_metadata_k_physical(); + int aligned_m_a = 
compressor_utility.get_tensorA_m_physical(); + int aligned_k_a = compressor_utility.get_tensorA_k_physical(); + /// Layout A and E + layout_A = SparseConfig::fill_layoutA(workload); + layout_E = SparseConfig::fill_layoutE(workload); + + block_A.reset(cutlass::make_Coord(aligned_m_a * aligned_k_a)); + block_E.reset(cutlass::make_Coord(aligned_m_e * aligned_k_e)); + block_A_Decompressed.reset(cutlass::make_Coord(options.m * options.k)); + initialize_block(block_A_Decompressed.host_view(), seed + 2020); + compressor_utility.structure_sparse_zero_mask_fill( + block_A_Decompressed.host_data(), static_cast(seed + 2021)); + block_A_Decompressed.sync_device(); + + /// Use compressor kernel to generate compressed Matrix A and E + cutlass::Status status { cutlass::Status::kSuccess }; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {options.m, options.n, options.k, 1}, + {block_A_Decompressed.device_data(), + stride_A, + block_A.device_data(), + block_E.device_data()}, + {hw_info} + }; + + // Compress A and E + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + return false; + } + + block_A.sync_host(); + block_E.sync_host(); + return true; +} +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(const Options &options) { + using namespace cute; + + // Initial A, E(metadata) and A_compressed blocks + if(!initialize_sparse_blocks(options)) return false; + + // Define B, C and D blocks + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + // Define SFA and SFB tensors layouts + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + block_B.sync_device(); + block_C.sync_device(); + 
block_SFA.sync_device();
+  block_SFB.sync_device();
+  return true;
+}
+// Populates a Gemm::Arguments structure from the given commandline options
+typename Gemm::Arguments args_from_options(const Options &options)
+{
+  typename Gemm::Arguments arguments {
+    cutlass::gemm::GemmUniversalMode::kGemm,
+    {options.m, options.n, options.k, 1},
+    { // Mainloop arguments
+      block_A.device_data(), layout_A,
+      block_B.device_data(), stride_B,
+      block_E.device_data(), layout_E,
+      block_SFA.device_data(), layout_SFA,
+      block_SFB.device_data(), layout_SFB
+    },
+    { // Epilogue arguments
+      {options.alpha, options.beta},
+      block_C.device_data(), stride_C,
+      block_D.device_data(), stride_D
+    }
+  };
+  return arguments;
+}
+bool verify(const Options &options) {
+  using namespace cute;
+  // Create the arguments for host reference implementation
+  Tensor tensor_A = make_tensor(make_iterator(block_A_Decompressed.host_data()), layout_A);
+  Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
+  Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
+  Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
+  Tensor tensor_E = make_tensor(make_iterator(block_E.host_data()), layout_E);
+
+  cutlass::reference::host::GettBlockScalingMainloopParams<
+      ElementAccumulator,   // ElementAccumulator
+      decltype(tensor_A),   // TensorA
+      decltype(tensor_SFA), // TensorSfA
+      decltype(tensor_B),   // TensorB
+      decltype(tensor_SFB)  // TensorSfB
+    > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
+  auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
+  auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
+
+  cutlass::reference::host::GettBlockScalingEpilogueParams<
+      ElementAccumulator,   // ElementScalar
+      ElementAccumulator,   // ElementAccumulator
+      ElementAccumulator,   // ElementCompute
+      decltype(tensor_C),   // TensorC
+      decltype(tensor_D)    // TensorD
+    > epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
+  cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
+  // Compare the CUTLASS kernel result (block_D) against the host reference (block_reference_D)
+  block_D.sync_host();
+
+  bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
+  passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
+  passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
+  return passed;
+}
+/// Execute a given example GEMM computation
+template
+int run(Options &options)
+{
+  // Initialization
+  if(!initialize(options))
+  {
+    std::cerr << " Initialization failed! 
" << std::endl; + exit(-1); + } + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + cudaDeviceSynchronize(); + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + if (!result.passed) { + exit(-1); + } + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + return 0; +} +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +/////////////////////////////////////////////////////////////////////////////////////////////////// +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 120. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 12 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." 
<< std::endl; + return 0; + } + // + // Parse options + // + Options options; + options.parse(argc, args); + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + return 0; +} +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu b/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu new file mode 100644 index 0000000000..f3441b5630 --- /dev/null +++ b/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu @@ -0,0 +1,578 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture. + + This example demonstrates a simple way to instantiate and run a narrow precision blockscaled sparse GEMM on the NVIDIA Blackwell SM120 architecture. + This kernel is optimized for the GeForce RTX 50 series GPUs. + + The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Sparse Tensor Core MMA Instructions: + * mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale. + Please see more detail in https://docs.nvidia.com/cuda/parallel-thread-execution. + + The kernel leverages: + 1. Warp-Specialized persistent kernel design that supports cooperative scheduler introduced in Hopper. + 2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + 3. 
Block Scaled Sparse Tensor Core MMA Instructions + + Note that GeForce RTX 50 series GPUs do not support: + 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. + 2. Dynamic datatypes. + + Usage: + $ ./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm --m=2048 --n=2048 --k=2048 +*/ +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" + +#include "helper.h" +using namespace cute; +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // Element type for B matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C/D matrix configuration +using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::ColumnMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::ColumnMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int outputVectorSize = 32; // Vector size for D matrix +using outputScaleFactor = cutlass::float_ue4m3_t; // Scale factor type for D matrix +// E matrix configuration. Note, E is used to represent metadata tensor. 
+using ElementE = uint8_t; // Element type for E matrix operand +// Kernel functional config +using ElementCompute = float; // Element type for computation inside mainloop and epilogue +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag +using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120; // Kernel schedule policy +// Kernel Perf config +using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::SparseTmaWarpSpecializedCooperativeSm120, // Epilogue schedule policy + cutlass::epilogue::fusion::LinCombBlockScaleFactor< // Epilogue fusion to generate nvfp4 output + outputVectorSize, ElementD, ElementAccumulator, outputScaleFactor, LayoutDTag, ElementC> + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelScheduleType // Mainloop schedule policy + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig; +using LayoutSFD = typename SfdOutputCfg::LayoutSF; +// +// Data members +// +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +LayoutSFD layout_SFD; +LayoutE layout_E; +uint64_t seed; +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_A_Decompressed; +cutlass::HostTensor block_E; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; 
+cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +cutlass::HostTensor block_SFD; +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +cutlass::HostTensor block_reference_SFD; +cutlass::HostTensor block_Normconst; +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// +// Command line options parsing +struct Options { + bool help; + float alpha, beta; + int iterations; + int m, n, k; + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + out << "80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm\n\n" + << " Blackwell MXFP8 Sparse GEMM is a warp specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + out << "\n\nExamples:\n\n" + << "$ " << "./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + return out; + } + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} +}; +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr 
(cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} +/// Initialize blocks that released to sparse Matrix A and its metadata E +bool initialize_sparse_blocks(const Options &options) { + auto workload = make_shape(options.m, + options.n, + options.k, + 1); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + /// Alias SparseConfig and Compressor + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA::DataType, + LayoutATag, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA::DataType, + LayoutATag, + SparseConfig, + cutlass::arch::Sm120>; + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + /// Declare compressor_utility to randomly fill zero in Matrix A to match sparsity needs + CompressorUtility compressor_utility(workload, stride_A); + // Aligned M K dimension size for A and E + int aligned_m_e = compressor_utility.get_metadata_m_physical(); + int aligned_k_e = compressor_utility.get_metadata_k_physical(); + int aligned_m_a = compressor_utility.get_tensorA_m_physical(); + int aligned_k_a = compressor_utility.get_tensorA_k_physical(); + /// Layout A and E + layout_A = SparseConfig::fill_layoutA(workload); + layout_E = SparseConfig::fill_layoutE(workload); + + block_A.reset(cutlass::make_Coord(aligned_m_a * aligned_k_a)); + block_E.reset(cutlass::make_Coord(aligned_m_e * aligned_k_e)); + block_A_Decompressed.reset(cutlass::make_Coord(options.m * options.k)); + initialize_block(block_A_Decompressed.host_view(), seed + 2020); + compressor_utility.structure_sparse_zero_mask_fill( + block_A_Decompressed.host_data(), static_cast(seed + 2021)); + block_A_Decompressed.sync_device(); + + /// Use compressor kernel to generate compressed Matrix A and E + cutlass::Status status { cutlass::Status::kSuccess }; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {options.m, options.n, options.k, 1}, + {block_A_Decompressed.device_data(), + stride_A, + block_A.device_data(), + block_E.device_data()}, + {hw_info} + }; + + // Compress A and E + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + return false; + } + + block_A.sync_host(); + block_E.sync_host(); + return true; +} +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(const Options &options) { + using namespace cute; + + // Initial A, E(metadata) and A_compressed blocks + if(!initialize_sparse_blocks(options)) return false; + + // Define B, C and D blocks + using Sm1xxBlkScaledConfig = typename 
Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1)); + // Define SFA and SFB tensors layouts + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + block_Normconst.reset(cutlass::make_Coord(1)); + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + block_Normconst.at(cutlass::make_Coord(0)) = 2; + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + block_SFD.sync_device(); + block_Normconst.sync_device(); + return true; +} +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), layout_A, + block_B.device_data(), stride_B, + block_E.device_data(), layout_E, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {options.alpha, options.beta}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + } + }; + arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data(); + arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data(); + return arguments; +} +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A_Decompressed.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + Tensor tensor_E = make_tensor(make_iterator(block_E.host_data()), layout_E); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, 
tensor_SFB};
+  auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
+  auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
+  auto tensor_SFD = cute::make_tensor(block_reference_SFD.host_data(), layout_SFD);
+
+  cutlass::reference::host::GettBlockScalingEpilogueParams<
+      ElementAccumulator,   // ElementScalar
+      ElementAccumulator,   // ElementAccumulator
+      ElementAccumulator,   // ElementCompute
+      decltype(tensor_C),   // TensorC
+      decltype(tensor_D),   // TensorD
+      decltype(tensor_SFD), // TensorSfD
+      cute::Int,
+      cutlass::reference::host::SfStrategy::SfDGen
+    > epilogue_params{options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
+  cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
+  // Compare the CUTLASS kernel result (block_D) against the host reference (block_reference_D)
+  block_D.sync_host();
+
+  bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
+  passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
+  passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
+  return passed;
+}
+/// Execute a given example GEMM computation
+template
+int run(Options &options)
+{
+  // Initialization
+  if(!initialize(options))
+  {
+    std::cerr << " Initialization failed! " << std::endl;
+    exit(-1);
+  }
+
+  // Instantiate CUTLASS kernel depending on templates
+  Gemm gemm;
+  // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
+  auto arguments = args_from_options(options);
+  // Using the arguments, query for extra workspace required for matrix multiplication computation
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  // Allocate workspace memory
+  cutlass::device_memory::allocation workspace(workspace_size);
+  // Check if the problem size is supported or not
+  CUTLASS_CHECK(gemm.can_implement(arguments));
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
+  // Correctness / Warmup iteration
+  CUTLASS_CHECK(gemm.run());
+  cudaDeviceSynchronize();
+  // Check if output from CUTLASS kernel and reference kernel are equal or not
+  Result result;
+  result.passed = verify(options);
+  std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
+  if (!result.passed) {
+    exit(-1);
+  }
+  // Run profiling loop
+  if (options.iterations > 0)
+  {
+    GpuTimer timer;
+    timer.start();
+    for (int iter = 0; iter < options.iterations; ++iter) {
+      CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
+      CUTLASS_CHECK(gemm.run());
+    }
+    timer.stop();
+    // Compute average runtime and GFLOPs.
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+    std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
+    std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << " GFLOPS: " << result.gflops << std::endl;
+  }
+  return 0;
+}
+#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
+///////////////////////////////////////////////////////////////////////////////////////////////////
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
+  // and must have compute capability at least 120.
+ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 12 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; + return 0; + } + // + // Parse options + // + Options options; + options.parse(argc, args); + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + return 0; +} +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt b/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt new file mode 100644 index 0000000000..6a94fb0d90 --- /dev/null +++ b/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +if (CUTLASS_NVCC_ARCHS MATCHES 120a) +cutlass_example_add_executable( + 80a_blackwell_geforce_mxfp8_bf16_sparse_gemm + 80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu +) + +cutlass_example_add_executable( + 80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm + 80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu +) + +endif() diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu index 3148d2aac2..10cfe89d3c 100644 --- a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu +++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu @@ -30,11 +30,9 @@ **************************************************************************************************/ /*! \file - \brief A FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. + \brief An FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. */ - - #include #include "cutlass/cutlass.h" @@ -115,7 +113,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutC, AlignmentD, - cutlass::epilogue::TmaWarpSpecialized1Sm + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -125,7 +123,7 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder ElementAccumulator, MmaTileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 // Note: Groupwise and Blockwise only support 1 SM MMA at this moment + cutlass::gemm::KernelScheduleSm100Blockwise >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -222,8 +220,7 @@ struct Options { } /// Compute performance in GFLOP/s - double gflops(double runtime_s) const - { + double gflops(double runtime_s) const { // Two flops per multiply-add uint64_t flop = uint64_t(2) * m * n * k; double gflop = double(flop) / double(1.0e9); @@ -232,8 +229,7 @@ struct Options { }; /// Result structure -struct Result -{ +struct Result { double avg_runtime_ms; double gflops; cutlass::Status status; @@ -273,13 +269,16 @@ bool initialize_tensor( if (bits_input == 1) { scope_max = 2; scope_min = 0; - } else if (bits_input <= 8) { + } + else if (bits_input <= 8) { scope_max = 2; scope_min = -2; - } else if (bits_output == 16) { + } + else if (bits_output == 16) { scope_max = 5; scope_min = -5; - } else { + } + else { scope_max = 8; scope_min = -8; } @@ -392,8 +391,7 @@ void initialize(const Options &options) { } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options) -{ +typename Gemm::Arguments args_from_options(const Options &options) { typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {options.m, options.n, options.k, options.l}, @@ -468,8 +466,7 @@ bool verify(const Options &options) { /// Execute a given example GEMM computation template -int run(Options &options) -{ +int run(Options &options) { initialize(options); @@ -510,8 +507,7 @@ int run(Options &options) } // Run profiling loop - if (options.iterations > 0) - { + if (options.iterations > 0) { GpuTimer timer; timer.start(); for (int iter = 0; iter < options.iterations; ++iter) { 
diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu index 11083e0981..6d8d1de019 100644 --- a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu +++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief A FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. + \brief An FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. */ #include @@ -96,9 +96,9 @@ using ElementCompute = float; // MMA and Cluster Tile Shapes // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 -using MmaTileShape_MNK = Shape<_128,_128,_128>; +using MmaTileShape_MNK = Shape<_256,_128,_128>; // Shape of the threadblocks in a cluster -using ClusterShape_MNK = Shape<_1,_1,_1>; +using ClusterShape_MNK = Shape<_2,_1,_1>; constexpr int ScaleGranularityM = 1; constexpr int ScaleGranularityN = 128; @@ -120,7 +120,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutC, AlignmentD, - cutlass::epilogue::TmaWarpSpecialized1Sm + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -130,7 +130,7 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder ElementAccumulator, MmaTileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 // Note: Groupwise and Blockwise only support 1 SM MMA at this moment + cutlass::gemm::KernelScheduleSm100Blockwise >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -227,8 +227,7 @@ struct Options { } /// Compute performance in GFLOP/s - double gflops(double runtime_s) const - { + double gflops(double runtime_s) const { // Two flops per multiply-add uint64_t flop = uint64_t(2) * m * n * k; double gflop = double(flop) / double(1.0e9); @@ -237,8 +236,7 @@ struct Options { }; /// Result structure -struct Result -{ +struct Result { double avg_runtime_ms; double gflops; cutlass::Status status; @@ -278,13 +276,16 @@ bool initialize_tensor( if (bits_input == 1) { scope_max = 2; scope_min = 0; - } else if (bits_input <= 8) { + } + else if (bits_input <= 8) { scope_max = 2; scope_min = -2; - } else if (bits_output == 16) { + } + else if (bits_output == 16) { scope_max = 5; scope_min = -5; - } else { + } + else { scope_max = 8; scope_min = -8; } @@ -397,9 +398,8 @@ void initialize(const Options &options) { } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options) -{ - typename Gemm::Arguments arguments{ +typename Gemm::Arguments args_from_options(const Options &options) { + typename Gemm::Arguments arguments { cutlass::gemm::GemmUniversalMode::kGemm, {options.m, options.n, options.k, options.l}, {tensor_A.device_data(), stride_A, @@ -473,8 +473,7 @@ bool verify(const Options &options) { /// Execute a given example GEMM computation template -int run(Options &options) -{ +int run(Options &options) { initialize(options); @@ -515,8 +514,7 @@ int 
run(Options &options) } // Run profiling loop - if (options.iterations > 0) - { + if (options.iterations > 0) { GpuTimer timer; timer.start(); for (int iter = 0; iter < options.iterations; ++iter) { diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu new file mode 100644 index 0000000000..b43869e7f1 --- /dev/null +++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu @@ -0,0 +1,754 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief An FP8 blockwise-scaled grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. + In this example M, N, and K are fixed across groups. 
+*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/gett.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// MMA type +using ElementAccumulator = float; +using ElementCompute = float; + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_128,_128,_128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_1,_1,_1>; +// Shape of the tile computed by each SM + +using ScaleConfig = decltype(cutlass::detail::sm100_trivial_blockwise_scale_config(MmaTileShape_MNK{})); + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, AlignmentC, + 
ElementD, LayoutC *, AlignmentD, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100 + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +static_assert(cute::is_same_v); +static_assert(cute::is_same_v); + +/// Initialization +uint64_t seed; + +// Host-side allocations +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; +std::vector offset_SFA; +std::vector offset_SFB; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; + +std::vector ptr_ref_D_host; + +std::vector ptr_A_host; +std::vector ptr_B_host; +std::vector ptr_C_host; +std::vector ptr_D_host; +std::vector ptr_SFA_host; +std::vector ptr_SFB_host; + +// Shared Allocations + +cutlass::HostTensor block_A; +cutlass::HostTensor block_B; +cutlass::HostTensor block_C; +cutlass::HostTensor block_D; +cutlass::HostTensor block_ref_D; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_SFB; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_SFA; +cutlass::DeviceAllocation ptr_SFB; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + bool skip_verification = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 2048, k = 512, groups = 10; + std::vector problem_sizes_host; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("skip-verification")) { + skip_verification = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + 
cmd.get_cmd_line_argument("iterations", iterations); + + for (int i = 0; i < groups; ++i) { + problem_sizes_host.push_back({m, n, k}); + } + + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "81_blackwell_grouped_gemm_blockwise\n\n" + << " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --skip-verification Skip verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "81_blackwell_grouped_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * groups; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Helper to initialize a block of device data (scale_tensors) +template +bool initialize_scale_tensor( + cutlass::TensorView view, + 
cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + + scope_min = -1; + scope_max = 1; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(Options const& options) { + int32_t total_elements_A = 0; + int32_t total_elements_B = 0; + int32_t total_elements_C = 0; + int32_t total_elements_D = 0; + int32_t total_elements_SFA = 0; + int32_t total_elements_SFB = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_SFA.push_back(total_elements_SFA); + offset_SFB.push_back(total_elements_SFB); + + int32_t elements_A = M * K; + int32_t elements_B = K * N; + int32_t elements_C = M * N; + int32_t elements_D = M * N; + + auto gemm_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + auto gemm_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); + + int32_t elements_SFA = cosize(gemm_layout_SFA); + int32_t elements_SFB = cosize(gemm_layout_SFB); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_elements_SFA += elements_SFA; + total_elements_SFB += elements_SFB; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + layout_SFA_host.push_back(gemm_layout_SFA); + layout_SFB_host.push_back(gemm_layout_SFB); + } + + block_A.resize(cutlass::make_Coord(total_elements_A)); + block_B.resize(cutlass::make_Coord(total_elements_B)); + block_C.resize(cutlass::make_Coord(total_elements_C)); + block_D.resize(cutlass::make_Coord(total_elements_D)); + block_ref_D.resize(cutlass::make_Coord(total_elements_D)); + block_SFA.resize(cutlass::make_Coord(total_elements_SFA)); + block_SFB.resize(cutlass::make_Coord(total_elements_SFB)); + + initialize_tensor(block_A.host_view(), cutlass::Distribution::Uniform, seed + 2022); + initialize_tensor(block_B.host_view(), cutlass::Distribution::Uniform, seed + 2023); + initialize_tensor(block_C.host_view(), cutlass::Distribution::Uniform, seed + 2024); + initialize_scale_tensor(block_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2026); 
+ initialize_scale_tensor(block_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2027); + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + + // copy problem sizes + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + std::vector device_ptr_A_host(options.groups); + std::vector device_ptr_B_host(options.groups); + std::vector device_ptr_C_host(options.groups); + std::vector device_ptr_D_host(options.groups); + std::vector device_ptr_SFA_host(options.groups); + std::vector device_ptr_SFB_host(options.groups); + + ptr_A_host = std::vector(options.groups); + ptr_B_host = std::vector(options.groups); + ptr_C_host = std::vector(options.groups); + ptr_D_host = std::vector(options.groups); + ptr_SFA_host = std::vector(options.groups); + ptr_SFB_host = std::vector(options.groups); + ptr_ref_D_host = std::vector(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + // Ptrs for A + ptr_A_host.at(i) = block_A.host_data() + offset_A.at(i); + device_ptr_A_host.at(i) = block_A.device_data() + offset_A.at(i); + + // Ptrs for B + ptr_B_host.at(i) = block_B.host_data() + offset_B.at(i); + device_ptr_B_host.at(i) = block_B.device_data() + offset_B.at(i); + + // Ptrs for C + ptr_C_host.at(i) = block_C.host_data() + offset_C.at(i); + device_ptr_C_host.at(i) = block_C.device_data() + offset_C.at(i); + + // Ptrs for D + ptr_D_host.at(i) = block_D.host_data() + offset_D.at(i); + device_ptr_D_host.at(i) = block_D.device_data() + offset_D.at(i); + ptr_ref_D_host.at(i) = block_ref_D.host_data() + offset_D.at(i); + + // Ptrs for SFA + ptr_SFA_host.at(i) = block_SFA.host_data() + offset_SFA.at(i); + device_ptr_SFA_host.at(i) = block_SFA.device_data() + offset_SFA.at(i); + + // Ptrs for SFB + ptr_SFB_host.at(i) = block_SFB.host_data() + offset_SFB.at(i); + device_ptr_SFB_host.at(i) = block_SFB.device_data() + offset_SFB.at(i); + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(device_ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(device_ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(device_ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(device_ptr_D_host.data()); + + ptr_SFA.reset(options.groups); + ptr_SFA.copy_from_host(device_ptr_SFA_host.data()); + + ptr_SFB.reset(options.groups); + ptr_SFB.copy_from_host(device_ptr_SFB_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), + 
ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), + ptr_SFB.get(), layout_SFB.get() + }, + { + {}, // epilogue.thread + ptr_C.get(), stride_C.get(), + ptr_D.get(), stride_D.get() + }, + hw_info + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + block_D.sync_host(); + + for (int i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(ptr_A_host.at(i), + cute::make_layout(cute::make_shape(M, K, 1), stride_A_host.at(i))); + auto B = cute::make_tensor(ptr_B_host.at(i), + cute::make_layout(cute::make_shape(N, K, 1), stride_B_host.at(i))); + auto C = cute::make_tensor(ptr_C_host.at(i), + cute::make_layout(cute::make_shape(M, N, 1), stride_C_host.at(i))); + auto D = cute::make_tensor(ptr_ref_D_host.at(i), + cute::make_layout(cute::make_shape(M, N, 1), stride_D_host.at(i))); + + auto SFA = cute::make_tensor(ptr_SFA_host.at(i), layout_SFA_host.at(i)); + auto SFB = cute::make_tensor(ptr_SFB_host.at(i), layout_SFB_host.at(i)); + + using unused_t = decltype(D); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, + ElementAccumulator, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + } + + bool passed = cutlass::reference::host::TensorEquals(block_ref_D.host_view(), block_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) { + initialize(options); + + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + Result result; + if (!options.skip_verification) { + // Check if output from CUTLASS kernel and reference kernel are equal or not + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+      float elapsed_ms = timer.elapsed_millis();
+      result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+      result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+      std::cout << "  Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.groups << std::endl;
+      std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+      std::cout << "  GFLOPS: " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with the CUDA 12.0 Toolkit to run this example,
+  // and the GPU must have compute capability at least sm100a.
+
+  if (__CUDACC_VER_MAJOR__ < 12) {
+    std::cerr << "This example requires CUDA 12 or newer.\n";
+    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (props.major != 10 || props.minor != 0) {
+    std::cerr << "This example requires a GPU with compute capability 100a." << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Run
+  //
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+  run<Gemm>(options);
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu
new file mode 100644
index 0000000000..60667cda29
--- /dev/null
+++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu
@@ -0,0 +1,761 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief An FP8 blockwise-scaled grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. + In this example M, N, and K are fixed across groups. +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/gett.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// MMA type +using ElementAccumulator = float; +using ElementCompute = float; + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 
2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_2,_1,_1>; +// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2 + +constexpr int ScaleGranularityM = 1; +constexpr int ScaleGranularityN = 128; +constexpr int ScaleGranularityK = 128; +using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig; + +// Note when we have multiple scale factors per tile (in this case 128 scales in M per tile), we will restrict up to a +// 16B alignment if possible (i.e., we have at least 16B of scales in M). +// In this case the smallest M that can be executed is 16. To avoid this for smaller M, you can swap A and B +// and transpose A, B, C, and scales since B^T A^T = C^T. +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutC *, AlignmentD, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100 + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +static_assert(cute::is_same_v); +static_assert(cute::is_same_v); + +/// Initialization +uint64_t seed; + +// Host-side allocations +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; +std::vector offset_SFA; +std::vector offset_SFB; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; + +std::vector ptr_ref_D_host; + +std::vector ptr_A_host; +std::vector ptr_B_host; +std::vector ptr_C_host; +std::vector ptr_D_host; +std::vector ptr_SFA_host; +std::vector ptr_SFB_host; + +// Shared Allocations + +cutlass::HostTensor block_A; +cutlass::HostTensor block_B; +cutlass::HostTensor block_C; +cutlass::HostTensor block_D; +cutlass::HostTensor block_ref_D; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_SFB; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_SFA; 
+cutlass::DeviceAllocation ptr_SFB; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + bool skip_verification = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 2048, k = 512, groups = 10; + std::vector problem_sizes_host; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("skip-verification")) { + skip_verification = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + for (int i = 0; i < groups; ++i) { + problem_sizes_host.push_back({m, n, k}); + } + + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "81_blackwell_grouped_gemm_groupwise\n\n" + << " Blackwell FP8 GEMM with Groupwise Scaling using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --skip-verification Skip verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "81_blackwell_grouped_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * groups; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double 
scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Helper to initialize a block of device data (scale_tensors) +template +bool initialize_scale_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + + scope_min = -1; + scope_max = 1; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + int32_t total_elements_A = 0; + int32_t total_elements_B = 0; + int32_t total_elements_C = 0; + int32_t total_elements_D = 0; + int32_t total_elements_SFA = 0; + int32_t total_elements_SFB = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_SFA.push_back(total_elements_SFA); + offset_SFB.push_back(total_elements_SFB); + + int32_t elements_A = M * K; + int32_t elements_B = K * N; + int32_t elements_C = M * N; + int32_t elements_D = M * N; + + auto gemm_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + auto gemm_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); + + int32_t elements_SFA = cosize(gemm_layout_SFA); + int32_t elements_SFB = cosize(gemm_layout_SFB); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += 
elements_C; + total_elements_D += elements_D; + total_elements_SFA += elements_SFA; + total_elements_SFB += elements_SFB; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + layout_SFA_host.push_back(gemm_layout_SFA); + layout_SFB_host.push_back(gemm_layout_SFB); + } + + block_A.resize(cutlass::make_Coord(total_elements_A)); + block_B.resize(cutlass::make_Coord(total_elements_B)); + block_C.resize(cutlass::make_Coord(total_elements_C)); + block_D.resize(cutlass::make_Coord(total_elements_D)); + block_ref_D.resize(cutlass::make_Coord(total_elements_D)); + block_SFA.resize(cutlass::make_Coord(total_elements_SFA)); + block_SFB.resize(cutlass::make_Coord(total_elements_SFB)); + + initialize_tensor(block_A.host_view(), cutlass::Distribution::Uniform, seed + 2022); + initialize_tensor(block_B.host_view(), cutlass::Distribution::Uniform, seed + 2023); + initialize_tensor(block_C.host_view(), cutlass::Distribution::Uniform, seed + 2024); + initialize_scale_tensor(block_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2026); + initialize_scale_tensor(block_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2027); + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + + // copy problem sizes + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + std::vector device_ptr_A_host(options.groups); + std::vector device_ptr_B_host(options.groups); + std::vector device_ptr_C_host(options.groups); + std::vector device_ptr_D_host(options.groups); + std::vector device_ptr_SFA_host(options.groups); + std::vector device_ptr_SFB_host(options.groups); + + ptr_A_host = std::vector(options.groups); + ptr_B_host = std::vector(options.groups); + ptr_C_host = std::vector(options.groups); + ptr_D_host = std::vector(options.groups); + ptr_SFA_host = std::vector(options.groups); + ptr_SFB_host = std::vector(options.groups); + ptr_ref_D_host = std::vector(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + // Ptrs for A + ptr_A_host.at(i) = block_A.host_data() + offset_A.at(i); + device_ptr_A_host.at(i) = block_A.device_data() + offset_A.at(i); + + // Ptrs for B + ptr_B_host.at(i) = block_B.host_data() + offset_B.at(i); + device_ptr_B_host.at(i) = block_B.device_data() + offset_B.at(i); + + // Ptrs for C + ptr_C_host.at(i) = block_C.host_data() + offset_C.at(i); + device_ptr_C_host.at(i) = block_C.device_data() + offset_C.at(i); + + // Ptrs for D + ptr_D_host.at(i) = block_D.host_data() + offset_D.at(i); + device_ptr_D_host.at(i) = block_D.device_data() + offset_D.at(i); + ptr_ref_D_host.at(i) = block_ref_D.host_data() + offset_D.at(i); + + // Ptrs for SFA + ptr_SFA_host.at(i) = block_SFA.host_data() + offset_SFA.at(i); + device_ptr_SFA_host.at(i) = block_SFA.device_data() + offset_SFA.at(i); + + // Ptrs for SFB + ptr_SFB_host.at(i) = block_SFB.host_data() + offset_SFB.at(i); + device_ptr_SFB_host.at(i) = block_SFB.device_data() + offset_SFB.at(i); + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(device_ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(device_ptr_B_host.data()); + + ptr_C.reset(options.groups); + 
ptr_C.copy_from_host(device_ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(device_ptr_D_host.data()); + + ptr_SFA.reset(options.groups); + ptr_SFA.copy_from_host(device_ptr_SFA_host.data()); + + ptr_SFB.reset(options.groups); + ptr_SFB.copy_from_host(device_ptr_SFB_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), + ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), + ptr_SFB.get(), layout_SFB.get() + }, + { + {}, // epilogue.thread + ptr_C.get(), stride_C.get(), + ptr_D.get(), stride_D.get() + }, + hw_info + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + block_D.sync_host(); + + for (int i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(ptr_A_host.at(i), + cute::make_layout(cute::make_shape(M, K, 1), stride_A_host.at(i))); + auto B = cute::make_tensor(ptr_B_host.at(i), + cute::make_layout(cute::make_shape(N, K, 1), stride_B_host.at(i))); + auto C = cute::make_tensor(ptr_C_host.at(i), + cute::make_layout(cute::make_shape(M, N, 1), stride_C_host.at(i))); + auto D = cute::make_tensor(ptr_ref_D_host.at(i), + cute::make_layout(cute::make_shape(M, N, 1), stride_D_host.at(i))); + + auto SFA = cute::make_tensor(ptr_SFA_host.at(i), layout_SFA_host.at(i)); + auto SFB = cute::make_tensor(ptr_SFB_host.at(i), layout_SFB_host.at(i)); + + using unused_t = decltype(D); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, + ElementAccumulator, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + } + + bool passed = cutlass::reference::host::TensorEquals(block_ref_D.host_view(), block_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) { + 
  initialize(options);
+
+  // Instantiate CUTLASS kernel depending on templates
+  Gemm gemm;
+
+  // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
+  auto arguments = args_from_options(options);
+
+  // Using the arguments, query for extra workspace required for matrix multiplication computation
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+
+  // Allocate workspace memory
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+  // Check if the problem size is supported or not
+  CUTLASS_CHECK(gemm.can_implement(arguments));
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
+
+  // Correctness / Warmup iteration
+  CUTLASS_CHECK(gemm.run());
+
+  Result result;
+  if (!options.skip_verification) {
+    // Check if output from CUTLASS kernel and reference kernel are equal or not
+    result.passed = verify(options);
+
+    std::cout << "  Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
+
+    if (!result.passed) {
+      exit(-1);
+    }
+  }
+
+  // Run profiling loop
+  if (options.iterations > 0) {
+    GpuTimer timer;
+    timer.start();
+    for (int iter = 0; iter < options.iterations; ++iter) {
+      CUTLASS_CHECK(gemm.run());
+    }
+    timer.stop();
+
+    // Compute average runtime and GFLOPs.
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+    std::cout << "  Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.groups << std::endl;
+    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << "  GFLOPS: " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with the CUDA 12.0 Toolkit to run this example,
+  // and the GPU must have compute capability at least sm100a.
+
+  if (__CUDACC_VER_MAJOR__ < 12) {
+    std::cerr << "This example requires CUDA 12 or newer.\n";
+    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (props.major != 10 || props.minor != 0) {
+    std::cerr << "This example requires a GPU with compute capability 100a."
<< std::endl; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/81_blackwell_gemm_blockwise/CMakeLists.txt b/examples/81_blackwell_gemm_blockwise/CMakeLists.txt index a4dc34d09e..8b98154627 100644 --- a/examples/81_blackwell_gemm_blockwise/CMakeLists.txt +++ b/examples/81_blackwell_gemm_blockwise/CMakeLists.txt @@ -54,4 +54,18 @@ cutlass_example_add_executable( TEST_SMALL ) +cutlass_example_add_executable( + 81_blackwell_grouped_gemm_blockwise + 81_blackwell_grouped_gemm_blockwise.cu + TEST_COMMAND_OPTIONS + TEST_SMALL +) + +cutlass_example_add_executable( + 81_blackwell_grouped_gemm_groupwise + 81_blackwell_grouped_gemm_groupwise.cu + TEST_COMMAND_OPTIONS + TEST_SMALL +) + endif() diff --git a/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu b/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu new file mode 100644 index 0000000000..f955b8e99b --- /dev/null +++ b/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu @@ -0,0 +1,869 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Distributed GEMM (DistGEMM) for Blackwell. + + This example runs Tensor Parallel GEMMs using the (experimental) Distributed GEMM API in + CUTLASS. For more information, please refer to README.md. 
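+
+    At a high level, the example shards one large GEMM across TP = 8 GPUs: each GPU
+    is assigned a 1/TP share of the computation, with its local operand slices
+    defined by the DistSchedule selected further down, and the TFLOP/s figure
+    reported at the end is therefore a per-GPU number.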
+ + Note that Distributed GEMM assumes an any-to-any NVLink network topology. + To check whether your device is compatible, run: + + $ nvidia-smi topo -m + + and make sure there's an any-to-any NVLink topology. It would look like this: + + GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 + GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18 + GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18 + GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18 + GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18 + GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18 + GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18 + GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18 + GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X + + You should also additionally check if the driver enables peer to peer access: + + $ nvidia-smi topo -p2p r + + Output should be something like this: + + GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 + GPU0 X OK OK OK OK OK OK OK + GPU1 OK X OK OK OK OK OK OK + GPU2 OK OK X OK OK OK OK OK + GPU3 OK OK OK X OK OK OK OK + GPU4 OK OK OK OK X OK OK OK + GPU5 OK OK OK OK OK X OK OK + GPU6 OK OK OK OK OK OK X OK + GPU7 OK OK OK OK OK OK OK X + + It is recommended to build this target with the following flag to enable + Grid Dependency Control instructions (GDC) in CUTLASS: + - CUTLASS_ENABLE_GDC_FOR_SM100 + + Example: + + $ mkdir build && cd build + + $ cmake .. -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1 + + $ cd examples/82_blackwell_distributed_gemm + + $ make + + $ ./82_blackwell_distributed_gemm +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +// Distributed GEMM headers +#include "cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp" +#include "cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp" +#include "cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp" + +#include "helper.h" + +// Distributed GEMM helpers +#include "dist_gemm_helpers.h" + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Distributed GEMM configuration +///////////////////////////////////////////////////////////////////////////////////////////////// + +// TP size (= number of processors/GPUs) +using TP = _8; +static constexpr int TP_ = TP{}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \ + (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) + +// Distributed GEMM tiling/sharding schedule +// Choices: +// +// * All Gather + GEMM: +// * AllGather1D_TilingCD_RotatingA +// * AllGather1D_TilingCD_RotatingB +// +// * GEMM + Reduce Scatter: +// * 
ReduceScatter1D_TilingA_RotatingC +// * ReduceScatter1D_TilingB_RotatingC + +using DistSchedule = cutlass::distributed::schedules::AllGather1D_TilingCD_RotatingA; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +using ElementD = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_256,_128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_2,_1,_1>; +// Shape of the tile computed by each SM +using PerSmTileShape_MNK = Shape<_128, _256, _128>; + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + PerSmTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + +// Compose into a kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + 
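+// The schedules listed above are interchangeable at this point in the file. As a
+// sketch (assuming the schedule templates take the TP count as a template
+// parameter), a GEMM + Reduce Scatter variant of this example would only change
+// the alias selected earlier, for instance:
+//
+//   using DistSchedule = cutlass::distributed::schedules::ReduceScatter1D_TilingA_RotatingC<TP>;
+//
+// The GemmKernel composition above and the DistGemm wrapper types below stay the same.
+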
+// We're going to use the single-device GEMM as reference +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Instantiate Distributed GEMM kernel +using DistGemmKernel = cutlass::distributed::kernel::DistributedGemmKernelWrapper< + GemmKernel, + DistSchedule +>; +using DistGemm = cutlass::distributed::device::DistributedGemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +using HostTensorA = typename cutlass::HostTensor; +using HostTensorB = typename cutlass::HostTensor; +using HostTensorC = typename cutlass::HostTensor; +using HostTensorD = typename cutlass::HostTensor; + +// Reference GEMM tensors +HostTensorA tensor_A; +HostTensorB tensor_B; +HostTensorC tensor_C; +HostTensorD tensor_D; +HostTensorD tensor_ref_D; + +// DistGEMM tensors (multi-device) +HostTensorA tensor_A_arr[TP_]; +HostTensorB tensor_B_arr[TP_]; +HostTensorD tensor_C_arr[TP_]; +HostTensorD tensor_D_arr[TP_]; + +#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 100; + int warmup_iterations = 10; + int m = 16384, n = 106496, k = 16384, l = 1; + float eps = 0.f; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("warmup-iterations", warmup_iterations); + cmd.get_cmd_line_argument("eps", eps); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "82_blackwell_distributed_gemm\n\n" + << " Blackwell Distributed GEMM (DistGEMM). 
\n" + << " For more details please refer to the source file.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch) of the GEMM (default: 1)\n" + << " --alpha= Epilogue scalar alpha (default: 1.0)\n" + << " --beta= Epilogue scalar beta (default: 0.0)\n" + << " --iterations= Number of profiling iterations to perform (default: 100)\n" + << " --warmup-iterations= Number of warmup iterations prior to profiling (default: 10)\n" + << " --eps= Threshold for error compared to reference " + << "GEMM (default: 0.0)\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "82_blackwell_distributed_gemm" << " --m=16384 --n=106496 --k=16384 \n\n"; + + return out; + } + + /// Compute performance in TFLOP/s + double tflops(double runtime_s) const { + + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l / TP_; + double tflop = double(flop) / double(1.0e12); + return tflop / runtime_s; + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double tflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double tflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), tflops(tflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \ + (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed, + bool is_device_tensor = false) { + + double scope_max, scope_min; + int bits = cutlass::sizeof_bits::value; + + if (bits == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits <= 16) { + scope_max = 2; + scope_min = -2; + } + else { + scope_max = 8; + scope_min = -8; + } + + if (is_device_tensor) { + using Real = typename cutlass::RealType::Type; + cutlass::reference::device::TensorFillRandomUniform( + view, seed, static_cast(scope_max), static_cast(scope_min), 0); + cudaDeviceSynchronize(); + } else { + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l); + + // Setup (reference) GEMM tensors + auto shape_A = cute::select<0,2,3>(problem_shape); + auto shape_B = cute::select<1,2,3>(problem_shape); + auto shape_C = cute::select<0,1,3>(problem_shape); + auto shape_D = cute::select<0,1,3>(problem_shape); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, shape_A); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D); + + auto a_coord = cutlass::make_Coord(size(shape_A), 1); + auto b_coord = cutlass::make_Coord(size(shape_B), 1); + auto c_coord = cutlass::make_Coord(size(shape_C), 1); + + 
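+  // The tensors resized below hold the full, un-sharded problem and are used for
+  // initialization and for the single-device reference GEMM. The per-device
+  // (sharded) tensors are sized further down from DistSchedule::get_local_*_shape,
+  // so each of the TP devices only holds its own slice of the operands.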
tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.device_view(), seed + 2022, /* is_device_tensor = */ true); + initialize_tensor(tensor_B.device_view(), seed + 2023, /* is_device_tensor = */ true); + initialize_tensor(tensor_C.device_view(), seed + 2024, /* is_device_tensor = */ true); + + tensor_A.sync_host(); + tensor_B.sync_host(); + tensor_C.sync_host(); + tensor_D.sync_host(); + tensor_ref_D.sync_host(); + + // Set up DistGEMM tensors + auto local_shape_A = DistSchedule::get_local_a_shape(problem_shape); + auto local_shape_B = DistSchedule::get_local_b_shape(problem_shape); + auto local_shape_C = DistSchedule::get_local_c_shape(problem_shape); + auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape); + + auto a_coord_device = cutlass::make_Coord(size(local_shape_A), 1); + auto b_coord_device = cutlass::make_Coord(size(local_shape_B), 1); + auto c_coord_device = cutlass::make_Coord(size(local_shape_C), 1); + + int primary_device_idx; + CUDA_CHECK(cudaGetDevice(&primary_device_idx)); + + // Enable any-to-any access + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + int can_access; + CUDA_CHECK(cudaSetDevice(device_idx)); + for (int peer_idx = 0; peer_idx < TP_; ++peer_idx) { + if (peer_idx != device_idx) { + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, device_idx, peer_idx)); + if (not can_access) { + std::cerr << "FAILURE: Device " << device_idx << " can't access device " << peer_idx << "." << + std::endl; + exit(EXIT_FAILURE); + } + CUDA_CHECK(cudaDeviceEnablePeerAccess(peer_idx, 0)); + } + } + + tensor_A_arr[device_idx].resize(a_coord_device); + tensor_B_arr[device_idx].resize(b_coord_device); + tensor_C_arr[device_idx].resize(c_coord_device); + tensor_D_arr[device_idx].resize(c_coord_device); + } + CUDA_CHECK(cudaSetDevice(primary_device_idx)); +} + +/// Commandline options -> Gemm/DistGemm Arguments +using GemmArguments = typename Gemm::Arguments; +GemmArguments gemm_args_from_options(const Options &options) { + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {static_cast(options.alpha), static_cast(options.beta)}, + tensor_C.device_data(), stride_C, + tensor_ref_D.device_data(), stride_D + } + }; + + return arguments; +} + +using DistGemmArguments = typename DistGemm::Arguments; +DistGemmArguments dist_gemm_args_from_options( + const Options &options, + int device_idx, + cudaStream_t stream) { + + auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l); + + auto global_A = cute::make_tensor(tensor_A.device_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto global_B = cute::make_tensor(tensor_B.device_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto global_C = cute::make_tensor(tensor_C.device_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + + auto global_A_device_slice = DistSchedule::get_device_slice_A(global_A, device_idx); + auto global_B_device_slice = DistSchedule::get_device_slice_B(global_B, device_idx); + auto global_C_device_slice = DistSchedule::get_device_slice_C(global_C, device_idx); + + auto local_shape_A = DistSchedule::get_local_a_shape(problem_shape); + auto local_shape_B = 
DistSchedule::get_local_b_shape(problem_shape); + auto local_shape_C = DistSchedule::get_local_c_shape(problem_shape); + auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape); + + auto local_stride_A = cutlass::make_cute_packed_stride(StrideA{}, local_shape_A); + auto local_stride_B = cutlass::make_cute_packed_stride(StrideB{}, local_shape_B); + auto local_stride_C = cutlass::make_cute_packed_stride(StrideC{}, local_shape_C); + auto local_stride_D = cutlass::make_cute_packed_stride(StrideD{}, local_shape_D); + + auto local_A = cute::make_tensor( + tensor_A_arr[device_idx].device_data(), + make_layout(local_shape_A, local_stride_A)); + auto local_B = cute::make_tensor( + tensor_B_arr[device_idx].device_data(), + make_layout(local_shape_B, local_stride_B)); + auto local_C = cute::make_tensor( + tensor_C_arr[device_idx].device_data(), + make_layout(local_shape_C, local_stride_C)); + auto local_D = cute::make_tensor( + tensor_D_arr[device_idx].device_data(), + make_layout(local_shape_D, local_stride_D)); + + // Copy over tensor tiles for the first iteration + cutlass::device_copy(global_A_device_slice, local_A, stream); + cutlass::device_copy(global_B_device_slice, local_B, stream); + cutlass::device_copy(global_C_device_slice, local_C, stream); + + DistGemmArguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, // mode + problem_shape, // problem shape + { + reinterpret_cast(local_A.data()), + local_A.stride(), + reinterpret_cast(local_B.data()), + local_B.stride() + }, // mainloop + { + { // epilogue.thread + static_cast(options.alpha), + static_cast(options.beta) + }, + reinterpret_cast(local_C.data()), + local_C.stride(), + reinterpret_cast(local_D.data()), + local_D.stride(), + }, // epilogue + {}, // hw_info + {} // scheduler + }; + + return arguments; +} + +// Gathers results, moves back to the original full-sized D tensor on the primary device. 
+void gather_results(const Options &options, int device_idx, cudaStream_t stream = nullptr) { + + auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l); + + // Global dest + auto global_D = cute::make_tensor(tensor_D.device_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + auto global_D_device_slice = DistSchedule::get_device_slice_D(global_D, device_idx); + + // Device_idx local dest + auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape); + auto local_stride_D = cutlass::make_cute_packed_stride(StrideD{}, local_shape_D); + auto local_D = cute::make_tensor( + tensor_D_arr[device_idx].device_data(), + make_layout(local_shape_D, local_stride_D) + ); + + // Copy to global dest + cutlass::device_copy(local_D, global_D_device_slice, stream); +} + +bool verify(const Options &options) { + tensor_D.sync_host(); + tensor_ref_D.sync_host(); + + bool passed = false; + if (options.eps == 0.f) { + passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + } else { + double err = cutlass::reference::host::TensorRelativeErrorMetric( + tensor_D.host_view(), + tensor_ref_D.host_view()); + passed = err < 1e-5; + } + + if (options.m <= 64 && options.n <= 64) { + std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n"; + std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n"; + } + + return passed; +} + +/// Execute a given example GEMM computation +int run(Options &options) { + + int primary_device_idx; + cudaError_t device_get_result = cudaGetDevice(&primary_device_idx); + if (device_get_result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + initialize(options); + + // Reference single-GPU GEMM + Gemm reference_gemm; + cutlass::device_memory::allocation reference_workspace; + + auto reference_arguments = gemm_args_from_options(options); + size_t reference_workspace_size = Gemm::get_workspace_size(reference_arguments); + reference_workspace = cutlass::device_memory::allocation(reference_workspace_size); + + CUTLASS_CHECK(reference_gemm.can_implement(reference_arguments)); + CUTLASS_CHECK(reference_gemm.initialize(reference_arguments, reference_workspace.get())); + CUTLASS_CHECK(reference_gemm.run()); + + using ElementBarrier = typename DistGemm::ElementBarrier; + using ElementFlag = typename DistGemmKernel::ElementFlag; + + // Set up per-device streams + cudaStream_t stream_arr[TP_]; + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + + // Create stream + CUDA_CHECK(cudaStreamCreate(&stream_arr[device_idx])); + } + + // Instantiate DistGEMM + DistGemm dist_gemm_arr[TP_]; // Distributed GEMM array for multiple devices + + // Allocate workspace memory + cutlass::device_memory::allocation workspace_arr[TP_]; + cutlass::device_memory::allocation exclusive_workspace_arr[TP_]; + + // Cross-device workspace pointer array for gemm.initialize() + void * workspace_ptr_arr[TP_]; + void * exclusive_workspace_ptr_arr[TP_]; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + DistGemmArguments arguments_[TP_]; + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + + arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = 
DistGemm::get_workspace_size(arguments_[device_idx]); + size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size(); + + workspace_arr[device_idx] = cutlass::device_memory::allocation(workspace_size); + exclusive_workspace_arr[device_idx] = cutlass::device_memory::allocation(exclusive_workspace_size); + + // Throw workspace pointers into arrays for gemm.initialize() + workspace_ptr_arr[device_idx] = workspace_arr[device_idx].get(); + exclusive_workspace_ptr_arr[device_idx] = exclusive_workspace_arr[device_idx].get(); + + // Zero out exclusive workspace + cudaMemsetAsync(exclusive_workspace_ptr_arr[device_idx], 0, exclusive_workspace_size, stream_arr[device_idx]); + + cudaDeviceSynchronize(); + } + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + + // Check if the problem size is supported or not + CUTLASS_CHECK(dist_gemm_arr[device_idx].can_implement(arguments_[device_idx])); + +#if defined(CUTLASS_ENABLE_GDC_FOR_SM100) + bool launch_with_pdl = true; +#else + bool launch_with_pdl = false; +#endif + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(dist_gemm_arr[device_idx].initialize( + arguments_, + workspace_ptr_arr, + exclusive_workspace_ptr_arr, + device_idx, + stream_arr[device_idx], + launch_with_pdl + )); + + cudaDeviceSynchronize(); + } + + // Correctness / Warmup iteration + std::cout << std::endl << " running DistGEMM..." << std::endl; + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx])); + } + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaStreamSynchronize(stream_arr[device_idx])); + CUDA_CHECK(cudaGetLastError()); + gather_results(options, device_idx); + } + + std::cout << " running DistGEMM finished without runtime errors" << std::endl; + + //// Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + + result.passed = verify(options); + + std::cout << std::endl << " Disposition (eps: " << options.eps << "): " << + (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) { + float elapsed_ms = 0.f; + + // Warmup + std::cout << " Warming up for " << options.warmup_iterations << " iterations." << std::endl; + for (int warmup_iter = 0; warmup_iter < options.warmup_iterations; ++warmup_iter) { + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx])); + } + } + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + CUDA_CHECK(cudaStreamSynchronize(stream_arr[device_idx])); + } + + CUDA_CHECK(cudaSetDevice(primary_device_idx)); + + // Benchmark + std::cout << " Profiling for " << options.iterations << " iterations." 
<< std::endl;
+    using AtomicBoolean = cuda::atomic<bool>;
+    AtomicBoolean* atomic_flag_ptr;
+    CUDA_CHECK(cudaHostAlloc(&atomic_flag_ptr, sizeof(AtomicBoolean), cudaHostAllocPortable));
+    atomic_flag_ptr->store(false);
+
+    cutlass::DistGpuTimer timer;
+
+    for (int device_idx = 0; device_idx < TP_; ++device_idx) {
+      CUDA_CHECK(cudaSetDevice(device_idx));
+      cutlass::delay_kernel<<<1, 1, 0, stream_arr[device_idx]>>>(atomic_flag_ptr);
+      CUDA_CHECK(cudaGetLastError());
+    }
+
+    for (int device_idx = 0; device_idx < TP_; ++device_idx) {
+      timer.start(device_idx, stream_arr[device_idx]);
+    }
+
+    atomic_flag_ptr->store(true);
+
+    for (int profile_iter = 0; profile_iter < options.iterations; ++profile_iter) {
+      for (int device_idx = 0; device_idx < TP_; ++device_idx) {
+        CUDA_CHECK(cudaSetDevice(device_idx));
+        CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx]));
+      }
+    }
+
+    for (int device_idx = 0; device_idx < TP_; ++device_idx) {
+      CUDA_CHECK(cudaSetDevice(device_idx));
+      timer.stop(device_idx, stream_arr[device_idx]);
+    }
+
+    CUDA_CHECK(cudaSetDevice(primary_device_idx));
+
+    for (int device_idx = 0; device_idx < TP_; ++device_idx) {
+      elapsed_ms = max(elapsed_ms, timer.elapsed_millis(device_idx));
+    }
+
+    // Compute average runtime and TFLOPs.
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0);
+    result.tflops = options.tflops(avg_runtime_s);
+
+    auto [local_M, local_N, local_K, local_L] = DistSchedule::get_local_gemm_shape(
+      cute::make_tuple(options.m, options.n, options.k, options.l));
+
+    std::cout << std::endl;
+    std::cout << "  TP: " << TP::value << std::endl;
+    std::cout << "  Problem Size: " <<
+      options.m << " x " <<
+      options.n << " x " <<
+      options.k << " x " <<
+      options.l << std::endl;
+    std::cout << "  Local GEMM Problem Size: " <<
+      local_M << " x " <<
+      local_N << " x " <<
+      local_K << " x " <<
+      local_L << std::endl;
+    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << "  TFLOPS: " << result.tflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example,
+  // and the GPU must have compute capability 100 (Blackwell SM100).
+  // Some necessary cuda graph APIs were only introduced in CUDA 12.4.
+  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
+    std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
+    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
+    return 0;
+  }
+
+  int num_devices;
+  CUDA_CHECK(cudaGetDeviceCount(&num_devices));
+  if (num_devices < TP_) {
+    std::cerr << "Distributed GEMM is compiled with TP = " << TP::value << ", but " <<
+      "found only " << num_devices << " devices." <<
+      std::endl;
+    // Returning zero so this test passes when fewer than TP devices are available. Its actions are no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (props.major != 10 || props.minor != 0) {
+    std::cerr
+      << "This example requires a GPU of NVIDIA's Blackwell Architecture "
+      << "(compute capability 100), "
+      << "got compute capability " << props.major * 10 + props.minor << "."
+      << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
+  run(options);
+#endif
+
+  return 0;
+}
diff --git a/examples/82_blackwell_distributed_gemm/CMakeLists.txt b/examples/82_blackwell_distributed_gemm/CMakeLists.txt
new file mode 100644
index 0000000000..fa8fe9adee
--- /dev/null
+++ b/examples/82_blackwell_distributed_gemm/CMakeLists.txt
@@ -0,0 +1,32 @@
+# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+cutlass_example_add_executable(
+  82_blackwell_distributed_gemm
+  82_blackwell_distributed_gemm.cu
+  )
diff --git a/examples/82_blackwell_distributed_gemm/README.md b/examples/82_blackwell_distributed_gemm/README.md
new file mode 100644
index 0000000000..6f6c19b867
--- /dev/null
+++ b/examples/82_blackwell_distributed_gemm/README.md
@@ -0,0 +1,37 @@
+# Blackwell Distributed GEMM
+
+This example implements Tensor Parallel GEMMs for the Blackwell architecture with the experimental
+[Distributed GEMM](../../include/cutlass/experimental/distributed) API in CUTLASS.
+
+This example requires Blackwell GPUs with an any-to-any NVLink network.
+Please refer to [REQUIREMENTS.md](REQUIREMENTS.md) for more information.
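+
+At a high level, the example wraps a regular single-device Blackwell GEMM kernel with a
+tensor-parallel schedule. A rough sketch of that composition, adapted from
+[82_blackwell_distributed_gemm.cu](82_blackwell_distributed_gemm.cu) (`GemmKernel` and
+`DistSchedule` are defined earlier in that file; the comments here are illustrative):
+
+```cpp
+// Single-device GEMM, also used as the verification reference.
+using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+// Wrap the base kernel with a distribution schedule to obtain the multi-GPU DistGEMM.
+using DistGemmKernel = cutlass::distributed::kernel::DistributedGemmKernelWrapper<
+  GemmKernel,
+  DistSchedule
+>;
+using DistGemm = cutlass::distributed::device::DistributedGemmUniversalAdapter<DistGemmKernel>;
+```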
+ +By default, the example assumes 8 GPUs (TP=8) and runs an All Gather + GEMM operation, which rotates +operand A. To run with a different number of GPUs or schedule, please refer to +[82_blackwell_distributed_gemm.cu](82_blackwell_distributed_gemm.cu). + + +## Getting started + +Command line arguments are mostly similar to other examples: + +``` +--m= Sets the M extent of the GEMM +--n= Sets the N extent of the GEMM +--k= Sets the K extent of the GEMM +--l= Sets the L extent (batch) of the GEMM (default: 1) +--alpha= Epilogue scalar alpha (default: 1.0) +--beta= Epilogue scalar beta (default: 0.0) +--iterations= Number of profiling iterations to perform (default: 100) +--warmup-iterations= Number of warmup iterations prior to profiling (default: 10) +--eps= Threshold for error compared to reference GEMM (default: 0.0) +``` + +Sample run command: + +```bash +./82_blackwell_distributed_gemm --m=16384 --n=106496 --k=16384 --warmup-iterations=10 --iterations=100 +``` + +This example follows the [Hopper example](../65_distributed_gemm/) very closely, and only differs in the base GEMM kernel. For +more information you can refer to [that example](../65_distributed_gemm/README.md). diff --git a/examples/82_blackwell_distributed_gemm/REQUIREMENTS.md b/examples/82_blackwell_distributed_gemm/REQUIREMENTS.md new file mode 100644 index 0000000000..3943716b2c --- /dev/null +++ b/examples/82_blackwell_distributed_gemm/REQUIREMENTS.md @@ -0,0 +1,86 @@ +# Blackwell Distributed GEMM + +## Requirements + +### Build +Make sure to set up CUTLASS with +support for [Programmatic Dependent Launch (PDL)](../../media/docs/dependent_kernel_launch.md), +that is with the `CUTLASS_ENABLE_GDC_FOR_SM100` flag. + +```bash +cmake $PATH -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1 +``` + +### Minimum software + +Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required. +This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary +CUDA graph APIs. + +### Hardware / driver settings + +This example requires Blackwell GPUs with NVLink network. 
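+
+In addition to the `nvidia-smi` checks below, peer (P2P) access can be verified programmatically
+with the CUDA runtime; the example itself performs a similar check at startup. A minimal
+standalone sketch (the file name `p2p_check.cu` is just a suggestion):
+
+```cpp
+// p2p_check.cu: report whether every pair of visible GPUs can access each other's memory.
+#include <cstdio>
+#include <cuda_runtime.h>
+
+int main() {
+  int num_devices = 0;
+  cudaGetDeviceCount(&num_devices);
+  bool all_ok = true;
+  for (int i = 0; i < num_devices; ++i) {
+    for (int j = 0; j < num_devices; ++j) {
+      if (i == j) continue;
+      int can_access = 0;
+      cudaDeviceCanAccessPeer(&can_access, i, j);
+      if (!can_access) {
+        std::printf("Device %d cannot access device %d\n", i, j);
+        all_ok = false;
+      }
+    }
+  }
+  std::printf(all_ok ? "Peer access OK between all device pairs\n"
+                     : "Peer access NOT available between all device pairs\n");
+  return all_ok ? 0 : 1;
+}
+```
+
+Build it with `nvcc p2p_check.cu -o p2p_check` and run it on the target node.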
+ +If you're not sure, first run the following command and make sure your GPU +compute capability is 10.0: + +```bash +nvidia-smi --query-gpu=name,compute_cap --format=csv +``` + +Sample output: + +``` +name, compute_cap +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +``` + + +Then you should make sure there is an NVLink network by checking the GPU network topology, +and making sure there's `NV*` links between every pair of GPUs: + +```bash +nvidia-smi topo -m +``` + +Sample output: + +``` + GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 +GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18 +GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18 +GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18 +GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18 +GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18 +GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18 +GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18 +GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X +``` + +Finally, check if the driver enables peer to peer access, which should usually be the case, +but it's good to check anyway: + +```bash +nvidia-smi topo -p2p r +``` + +Sample output: + +``` + GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 +GPU0 X OK OK OK OK OK OK OK +GPU1 OK X OK OK OK OK OK OK +GPU2 OK OK X OK OK OK OK OK +GPU3 OK OK OK X OK OK OK OK +GPU4 OK OK OK OK X OK OK OK +GPU5 OK OK OK OK OK X OK OK +GPU6 OK OK OK OK OK OK X OK +GPU7 OK OK OK OK OK OK OK X +``` diff --git a/examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu b/examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu new file mode 100644 index 0000000000..d428047219 --- /dev/null +++ b/examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu @@ -0,0 +1,607 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief An FP16 sparse GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
+
+    The Blackwell SM100 CUTLASS kernel uses the following Blackwell SM100 features:
+
+    1. A new series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a)
+    which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA).
+
+    Note that Hopper WGMMA Tensor Core MMA instructions are not compatible with Blackwell (See https://docs.nvidia.com/cuda/parallel-thread-execution).
+
+    2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a).
+    Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the
+    Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
+
+    3. An extended flavor of the warp-specialized kernel design introduced in Hopper, enabled by use of TMEM,
+    which allows us to decouple the execution of MMA and epilogue into separate warps.
+
+    4. A new SW-controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
+
+    Usage:
+      $ ./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm --m=8192 --n=8192 --k=8192
+*/
+
+#include <iostream>
+
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler_params.h"
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/distribution.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+#include "cutlass/util/reference/device/tensor_fill.h"
+#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
+#include "cutlass/transform/device/transform_universal_adapter.hpp"
+
+#include "helper.h"
+
+using namespace cute;
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/// GEMM kernel configurations
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// A matrix configuration
+using         ElementA    = half_t;                                          // Element type for A matrix operand
+using         LayoutTagA  = cutlass::layout::RowMajor;                       // Layout type for A matrix operand
+constexpr int AlignmentA  = 2 * 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes), 2x for compress along k
+
+// E matrix config
+using         ElementE    = cute::uint8_t;
+
+// B matrix configuration
+using         ElementB    = half_t;                                          // Element type for B matrix operand
+using         LayoutTagB  = cutlass::layout::ColumnMajor;                    // Layout type for B matrix operand
+constexpr int AlignmentB  = 128 / cutlass::sizeof_bits<ElementB>::value;     // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+
+// C/D matrix configuration
+using         ElementD    = float;                                           // Element type for D matrix operand
+using         ElementC
= float; // Element type for C matrix operand +using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C matrix operand +using LayoutTagD = cutlass::layout::ColumnMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassSparseTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_64>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_2,_1,_1>; + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementD, LayoutTagD, AlignmentD, + cutlass::epilogue::TmaWarpSpecialized2Sm + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveoutEpi, + cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + +using ProblemShape = Shape; + +// Compose into a kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutTagA, + ElementB, + LayoutTagB, + ElementC, + LayoutTagC, + ElementAccumulator, + ElementAccumulator>; + +// Layouts +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using StrideA = cutlass::gemm::TagToStrideA_t; +using StrideE = StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Compressor +// +using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + +using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + +using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + ArchTag>; + +using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +// +// Data members +// + +/// Initialization +LayoutA layout_A; +LayoutE layout_E; +StrideA stride_A; +StrideA stride_A_compressed; +StrideE stride_E; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +uint64_t seed; + +ProblemShape 
problem_shape; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_A_compressed; +cutlass::DeviceAllocation block_E; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k, l; + + Options(): + help(false), + m(8192), n(8192), k(8192), l(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "83_blackwell_sparse_gemm\n\n" + << " Blackwell FP16 Sparse GEMM example.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "83_blackwell_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } + else if constexpr (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } + else { + scope_max = Element(8); + scope_min = Element(-8); + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), 
seed, scope_max, scope_min, 0); + return true; +} + +/// Make A structured sparse by replacing elements with 0 and compress it +bool sparsify_and_compress() +{ + auto [M, N, K, L] = problem_shape; + CompressorUtility compressor_utility(problem_shape, stride_A); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + block_A_compressed.reset(M * KAlignedAC * L); + block_E.reset(MAlignedE * KAlignedE * L); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); + + // Random 50% fill zero is performed on host + std::vector block_A_host(block_A.size()); + cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size()); + compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast(seed + 2024)); + cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments { + problem_shape, + { block_A.get(), + stride_A, + block_A_compressed.get(), + block_E.get() }, + {hw_info} }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status {cutlass::Status::kSuccess }; + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + if (status != cutlass::Status::kSuccess) { + return false; + } + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + return false; + } + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + block_A.reset(options.m * options.k); + block_B.reset(options.k * options.n); + block_C.reset(options.m * options.n); + block_D.reset(options.m * options.n); + block_ref_D.reset(options.m * options.n); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + + // Compress row A and get A_compress and E + problem_shape = make_tuple(options.m, options.n, options.k, options.l); + if (not sparsify_and_compress()) { + return false; + }; + + // Build the compressed/metadata layouts + layout_A = 
SparseConfig::fill_layoutA(problem_shape); + layout_E = SparseConfig::fill_layoutE(problem_shape); + + return true; +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E }, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + auto init_pass = initialize(options); + if (not init_pass) { + std::cout << "Initialization failure" << std::endl; + exit(EXIT_FAILURE); + } + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (not result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+    std::cout << "  Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
+    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << "  GFLOPS: " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with CUDA Toolkit 12.8 or newer to run this example
+  // and must have compute capability at least 100.
+  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
+    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
+    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (not (props.major == 10 && props.minor == 0)) {
+    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+  run(options);
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/83_blackwell_sparse_gemm/CMakeLists.txt b/examples/83_blackwell_sparse_gemm/CMakeLists.txt
new file mode 100644
index 0000000000..765ef4c4ad
--- /dev/null
+++ b/examples/83_blackwell_sparse_gemm/CMakeLists.txt
@@ -0,0 +1,38 @@
+
+# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +if (CUTLASS_NVCC_ARCHS MATCHES 100a) + +cutlass_example_add_executable( + 83_blackwell_sparse_gemm + 83_blackwell_sparse_gemm.cu +) + +endif() diff --git a/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu b/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu new file mode 100644 index 0000000000..d2d87c4697 --- /dev/null +++ b/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu @@ -0,0 +1,693 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A Narrow Precision Sparse GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture. + + This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 Sparse GEMM on the NVIDIA Blackwell SM100 architecture. 
+ + The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced + on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma) + and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Similar to 83_blackwell_sparse_gemm, this kernel leverages: + 1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). + + 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + + 3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Usage: + $ ./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e2m1_t; +using ElementAPair = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes), 2x for compress along k + +// E matrix config +using ElementE = cute::uint8_t; +using LayoutTagE = LayoutTagA; + +// B matrix configuration +using ElementB = cutlass::float_e2m1_t; +using ElementBPair = cutlass::nv_float4_t; // Element type for B matrix operand +using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// SF +using ElementSF = typename ElementAPair::ScaleFactorType; + +// C/D matrix configuration 
+using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutTagC = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutTagD = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = (16 * 8) / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = (16 * 8) / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape = Shape<_256,_128,_256>; +// Shape of the threadblocks in a cluster +using ClusterShape = Shape<_2,_1,_1>; + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementD, LayoutTagD, AlignmentD, + cutlass::epilogue::TmaWarpSpecialized2SmNvf4 + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementAPair, LayoutTagA, AlignmentA, + ElementBPair, LayoutTagB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveoutEpi, + cutlass::gemm::KernelSparseTmaWarpSpecialized2SmNvf4Sm100 + >::CollectiveOp; + +using ProblemShape = Shape; + +// Compose into a kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// +// Blockscale +// +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; +using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; +using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; +using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using StrideA = cutlass::gemm::TagToStrideA_t; +using StrideE = StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; + +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
+ +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Compressor +// +using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + +using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + +using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + ArchTag>; + +using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideA stride_A_compressed; +StrideE stride_E; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +LayoutA layout_A; +LayoutE layout_E; +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; + +typename LayoutTagA::Stride stride_factor_A; +typename LayoutTagB::Stride stride_factor_B; +typename LayoutTagE::Stride stride_factor_E; +typename LayoutTagC::Stride stride_factor_C; +typename LayoutTagD::Stride stride_factor_D; + +uint64_t seed; + +ProblemShape problem_shape; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_A_compressed; +cutlass::HostTensor tensor_E; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_SFA; +cutlass::HostTensor tensor_SFB; +cutlass::HostTensor tensor_D; +cutlass::HostTensor reference_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k, l; + + Options(): + help(false), + m(1024), n(1024), k(1024), l(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "84a_blackwell_nvfp4_bf16_sparse_gemm\n\n" + << " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +void initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if (bits_input <= 8) { + if constexpr (cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(const Options &options) { + + problem_shape = make_tuple(options.m, options.n, options.k, options.l); + + // * Get A B C D size + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + layout_A = SparseConfig::fill_layoutA(problem_shape); + layout_E = SparseConfig::fill_layoutE(problem_shape); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape); + + // * Get ACompress & E size + CompressorUtility compressor_utility(problem_shape, stride_A); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = 
compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, KAlignedAC, options.l)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, options.l)); + + // * Get SFA & SFB size + auto k_blks = cutlass::ceil_div(options.k, cute::size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(options.m, Blk_MN{}); + auto n_blks = cutlass::ceil_div(options.n, Blk_MN{}); + + // * Allocate Tensor + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto e_coord = cutlass::make_Coord(MAlignedE * options.l, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * options.l, KAlignedAC); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto d_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * options.l, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * options.l, k_blks * Blk_SF{}); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_compressed.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(d_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(d_coord, stride_factor_D), false); + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + // * Random init + initialize_tensor(tensor_A.host_view(), seed + 2021); + initialize_tensor(tensor_B.host_view(), seed + 2022); + initialize_tensor(tensor_C.host_view(), seed + 2023); + initialize_tensor(tensor_SFA.host_view(), seed + 2024); + initialize_tensor(tensor_SFB.host_view(), seed + 2025); + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + // * Random fill 50% A with zero + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + // * Compress + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + problem_shape, + {tensor_A.device_data(), + stride_A, + tensor_A_compressed.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = 
Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status {cutlass::Status::kSuccess }; + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + if (status != cutlass::Status::kSuccess) { + return false; + } + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + return false; + } + + tensor_E.sync_host(); + tensor_A_compressed.sync_host(); + + return true; +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { + reinterpret_cast(tensor_A_compressed.device_data()), layout_A, + reinterpret_cast(tensor_B.device_data()), stride_B, + tensor_E.device_data(), layout_E, + tensor_SFA.device_data(), layout_SFA, + tensor_SFB.device_data(), layout_SFB + }, + { + {options.alpha, options.beta}, + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + + // Create the arguments for host reference implementation + auto A = make_tensor(make_iterator(tensor_A.host_data()), layout_A); + auto SFA = make_tensor(tensor_SFA.host_data(), layout_SFA); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(options.n, options.k, options.l), stride_B)); + auto SFB = make_tensor(tensor_SFB.host_data(), layout_SFB); + + cutlass::reference::host::GettMainloopParams< + ElementAccumulator, + decltype(A), + decltype(B), + decltype(SFA), + decltype(SFB)> mainloop_params{A, SFA, B, SFB}; + + auto C = make_tensor(make_iterator(tensor_C.host_data()), + make_layout(make_shape(options.m, options.n, options.l), stride_C)); + auto D = make_tensor(make_iterator(reference_D.host_data()), + make_layout(make_shape(options.m, options.n, options.l), stride_D)); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(C), // TensorC + decltype(D) // TensorD + > epilogue_params{ + options.alpha, + options.beta, + C, + D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(tensor_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + auto init_pass = initialize(options); + if (not init_pass) { + std::cout << "Initialization failure" << std::endl; + exit(EXIT_FAILURE); + } + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // 
Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (not result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (not (props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." 
<< std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu b/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu new file mode 100644 index 0000000000..a23af1581d --- /dev/null +++ b/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu @@ -0,0 +1,695 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A Narrow Precision Sparse GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture. + + This example demonstrates a simple way to instantiate and run a blockscaled MXFP8 Sparse GEMM on the NVIDIA Blackwell SM100 architecture. + + The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced + on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma) + and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Similar to 83_blackwell_sparse_gemm, this kernel leverages: + 1. 
Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). + + 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + + 3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Usage: + $ ./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; +using ElementAPair = cutlass::mx_float8_t; // Element type for A matrix operand +using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes), 2x for compress along k + +// E matrix config +using ElementE = cute::uint8_t; +using LayoutTagE = LayoutTagA; + +// B matrix configuration +using ElementB = cutlass::float_e2m1_t; +using ElementBPair = cutlass::mx_float4_t; // Element type for B matrix operand +using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// SF +using ElementSF = typename ElementAPair::ScaleFactorType; + +// C/D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutTagC = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutTagD = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = (16 * 8) / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 
16 bytes) +constexpr int AlignmentC = (16 * 8) / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_256>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_2,_1,_1>; + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementD, LayoutTagD, AlignmentD, + cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4 + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementAPair, LayoutTagA, AlignmentA, + ElementBPair, LayoutTagB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveoutEpi, + cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + +using ProblemShape = Shape; + +// Compose into a kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// +// Blockscale +// +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; +using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; +using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; +using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using StrideA = cutlass::gemm::TagToStrideA_t; +using StrideE = StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; + +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
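+Conceptually, the block-scaled mainloop built above computes an ordinary GEMM in which every fixed-size chunk of A and B along K shares one scale factor, so each accumulated product is `(sfa * a) * (sfb * b)` in `ElementAccumulator`. The host sketch below illustrates that contraction for a single output element. It is an assumption-laden illustration, not the CUTLASS reference: the 32-element block size follows the MX convention, and the scales are kept as plain `float` here, whereas the kernel stores them in the narrow `ElementSF` (`ScaleFactorType`) declared above.
+
+```cpp
+// Illustrative host sketch of a block-scaled dot product: one shared scale factor per
+// VecSize-wide block along K, applied to A and B before accumulation in float.
+// Assumptions: VecSize = 32 (MX convention) and float scales instead of the kernel's
+// narrow ScaleFactorType.
+#include <cstdio>
+#include <vector>
+
+float blockscaled_dot(std::vector<float> const& a,   std::vector<float> const& b,
+                      std::vector<float> const& sfa, std::vector<float> const& sfb,
+                      int K, int VecSize = 32) {
+  float acc = 0.f;
+  for (int k = 0; k < K; ++k) {
+    int blk = k / VecSize;                        // scale-factor block this k falls into
+    acc += (sfa[blk] * a[k]) * (sfb[blk] * b[k]);
+  }
+  return acc;
+}
+
+int main() {
+  int K = 64;
+  std::vector<float> a(K, 1.f), b(K, 2.f);
+  std::vector<float> sfa = {0.5f, 2.0f};          // one scale per 32-wide block of A
+  std::vector<float> sfb = {1.0f, 0.25f};         // one scale per 32-wide block of B
+  std::printf("dot = %f\n", blockscaled_dot(a, b, sfa, sfb, K));  // 32*1 + 32*1 = 64
+  return 0;
+}
+```
+
+The epilogue then applies `alpha`/`beta` against C and converts to `ElementD`, which is what the `verify()` routines in these examples reproduce with the host `Gemm3x` reference.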
+ +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Compressor +// +using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + +using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + +using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + ArchTag>; + +using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideA stride_A_compressed; +StrideE stride_E; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +LayoutA layout_A; +LayoutE layout_E; +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; + +typename LayoutTagA::Stride stride_factor_A; +typename LayoutTagB::Stride stride_factor_B; +typename LayoutTagE::Stride stride_factor_E; +typename LayoutTagC::Stride stride_factor_C; +typename LayoutTagD::Stride stride_factor_D; + +uint64_t seed; + +ProblemShape problem_shape; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_A_compressed; +cutlass::HostTensor tensor_E; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_SFA; +cutlass::HostTensor tensor_SFB; +cutlass::HostTensor tensor_D; +cutlass::HostTensor reference_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k, l; + + Options(): + help(false), + m(1024), n(1024), k(1024), l(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "84b_blackwell_mixed_mxfp8_bf16_sparse_gemm\n\n" + << " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +void initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if (bits_input <= 8) { + if constexpr (cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(const Options &options) { + + problem_shape = make_tuple(options.m, options.n, options.k, options.l); + + // * Get A B C D size + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + layout_A = SparseConfig::fill_layoutA(problem_shape); + layout_E = SparseConfig::fill_layoutE(problem_shape); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape); + + // * Get ACompress & E size + CompressorUtility compressor_utility(problem_shape, stride_A); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = 
compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, KAlignedAC, options.l)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, options.l)); + + // * Get SFA & SFB size + auto k_blks = cutlass::ceil_div(options.k, cute::size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(options.m, Blk_MN{}); + auto n_blks = cutlass::ceil_div(options.n, Blk_MN{}); + + // * Allocate Tensor + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto e_coord = cutlass::make_Coord(MAlignedE * options.l, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * options.l, KAlignedAC); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto d_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * options.l, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * options.l, k_blks * Blk_SF{}); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_compressed.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(d_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(d_coord, stride_factor_D), false); + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + // * Random init + initialize_tensor(tensor_A.host_view(), seed + 2021); + initialize_tensor(tensor_B.host_view(), seed + 2022); + initialize_tensor(tensor_C.host_view(), seed + 2023); + initialize_tensor(tensor_SFA.host_view(), seed + 2024); + initialize_tensor(tensor_SFB.host_view(), seed + 2025); + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + // * Random fill 50% A with zero + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + // * Compress + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + problem_shape, + {tensor_A.device_data(), + stride_A, + tensor_A_compressed.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = 
Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status {cutlass::Status::kSuccess }; + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + if (status != cutlass::Status::kSuccess) { + return false; + } + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + return false; + } + + tensor_E.sync_host(); + tensor_A_compressed.sync_host(); + + return true; +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { + reinterpret_cast(tensor_A_compressed.device_data()), layout_A, + reinterpret_cast(tensor_B.device_data()), stride_B, + tensor_E.device_data(), layout_E, + tensor_SFA.device_data(), layout_SFA, + tensor_SFB.device_data(), layout_SFB + }, + { + {options.alpha, options.beta}, + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + + // Create the arguments for host reference implementation + auto A = make_tensor(make_iterator(tensor_A.host_data()), layout_A); + auto SFA = make_tensor(tensor_SFA.host_data(), layout_SFA); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(options.n, options.k, options.l), stride_B)); + auto SFB = make_tensor(tensor_SFB.host_data(), layout_SFB); + + cutlass::reference::host::GettMainloopParams< + ElementAccumulator, + decltype(A), + decltype(B), + decltype(SFA), + decltype(SFB)> mainloop_params{A, SFA, B, SFB}; + + auto C = make_tensor(make_iterator(tensor_C.host_data()), + make_layout(make_shape(options.m, options.n, options.l), stride_C)); + auto D = make_tensor(make_iterator(reference_D.host_data()), + make_layout(make_shape(options.m, options.n, options.l), stride_D)); + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementScalingFactor + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(C), // TensorC + decltype(D) // TensorD + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(tensor_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + auto init_pass = initialize(options); + if (not init_pass) { + std::cout << "Initialization failure" << std::endl; + exit(EXIT_FAILURE); + } + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a 
structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (not result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (not (props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." 
<< std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/84_blackwell_narrow_precision_sparse_gemm/CMakeLists.txt b/examples/84_blackwell_narrow_precision_sparse_gemm/CMakeLists.txt new file mode 100644 index 0000000000..751590b702 --- /dev/null +++ b/examples/84_blackwell_narrow_precision_sparse_gemm/CMakeLists.txt @@ -0,0 +1,41 @@ + +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
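+Both sparse examples above zero half of A in a structured pattern (`structure_sparse_zero_mask_fill`) and then run the `StructuredSparseCompressor` to produce `tensor_A_compressed` (half the K extent) plus the metadata tensor `tensor_E`. The sketch below shows the underlying 2:4 idea on the host, using a simplified one-byte-per-group metadata encoding of my own; the real `ElementE` layout produced by the compressor is a packed, hardware-specific format, so treat this purely as an illustration of the concept.
+
+```cpp
+// Host-side illustration of 2:4 structured sparsity: of every 4 consecutive values along K,
+// at most 2 are nonzero; the compressed tensor keeps those 2 values (K shrinks by half) and
+// the metadata records which of the 4 positions they came from. Simplified encoding: one
+// byte per group holding two 2-bit position indices. Assumes exactly 2 nonzeros per group.
+#include <array>
+#include <cstdint>
+#include <cstdio>
+#include <vector>
+
+void compress_2of4(std::vector<float> const& A, int M, int K,
+                   std::vector<float>& A_comp, std::vector<uint8_t>& meta) {
+  A_comp.assign(M * K / 2, 0.f);
+  meta.assign(M * K / 4, 0);
+  for (int m = 0; m < M; ++m) {
+    for (int g = 0; g < K / 4; ++g) {
+      std::array<int, 2> kept{0, 1};               // positions of the survivors in the group
+      int n_kept = 0;
+      for (int j = 0; j < 4 && n_kept < 2; ++j) {
+        if (A[m * K + 4 * g + j] != 0.f) { kept[n_kept++] = j; }
+      }
+      for (int i = 0; i < 2; ++i) {
+        A_comp[m * (K / 2) + 2 * g + i] = A[m * K + 4 * g + kept[i]];
+      }
+      meta[m * (K / 4) + g] = uint8_t(kept[0] | (kept[1] << 2));  // two 2-bit indices
+    }
+  }
+}
+
+int main() {
+  int M = 1, K = 8;
+  std::vector<float> A = {0, 3, 0, 5,  7, 0, 2, 0};  // 2:4 sparse along K
+  std::vector<float> A_comp;
+  std::vector<uint8_t> meta;
+  compress_2of4(A, M, K, A_comp, meta);
+  std::printf("compressed: %g %g %g %g\n", A_comp[0], A_comp[1], A_comp[2], A_comp[3]);  // 3 5 7 2
+  std::printf("metadata:   0x%02x 0x%02x\n", unsigned(meta[0]), unsigned(meta[1]));      // 0x0d (positions 1,3), 0x08 (positions 0,2)
+  return 0;
+}
+```
+
+Roughly speaking, this is the split the device mainloop consumes: `tensor_A_compressed` feeds the sparse MMA while `tensor_E` tells it which K positions the kept values came from.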
+ + +if (CUTLASS_NVCC_ARCHS MATCHES 100a) +cutlass_example_add_executable( + 84a_blackwell_nvfp4_bf16_sparse_gemm + 84a_blackwell_nvfp4_bf16_sparse_gemm.cu + ) + +cutlass_example_add_executable( + 84b_blackwell_mixed_mxfp8_bf16_sparse_gemm + 84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu + ) +endif() diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 84fc931118..f041869cc7 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -159,17 +159,21 @@ foreach(EXAMPLE 67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling 68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling 69_hopper_mixed_dtype_grouped_gemm - 70_blackwell_gemm - 71_blackwell_gemm_with_collective_builder - 72_blackwell_narrow_precision_gemm - 73_blackwell_gemm_preferred_cluster - 74_blackwell_gemm_streamk - 75_blackwell_grouped_gemm - 76_blackwell_conv - 77_blackwell_fmha - 78_blackwell_emulated_bf16x9_gemm + 70_blackwell_gemm + 71_blackwell_gemm_with_collective_builder + 72_blackwell_narrow_precision_gemm + 73_blackwell_gemm_preferred_cluster + 74_blackwell_gemm_streamk + 75_blackwell_grouped_gemm + 76_blackwell_conv + 77_blackwell_fmha + 78_blackwell_emulated_bf16x9_gemm 79_blackwell_geforce_gemm + 80_blackwell_geforce_sparse_gemm 81_blackwell_gemm_blockwise + 82_blackwell_distributed_gemm + 83_blackwell_sparse_gemm + 84_blackwell_narrow_precision_sparse_gemm ) add_subdirectory(${EXAMPLE}) endforeach() diff --git a/examples/README.md b/examples/README.md index 150115db1a..5bed6853d7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -286,6 +286,18 @@ Blackwell SM120 MMA kernel targeting GeForce RTX 50 series CUDA Cores +* [80_blackwell_geforce_sparse_gemm](80_blackwell_geforce_sparse_gemm/) + + Blackwell SM120 sparse MMA kernel targeting GeForce RTX 50 series CUDA Cores + +* [83_blackwell_sparse_gemm](83_blackwell_sparse_gemm) + + Blackwell SM100 Sparse Gemm kernel + +* [84_blackwell_narrow_precision_sparse_gemm](84_blackwell_narrow_precision_sparse_gemm) + + Blackwell Block Scaled SM100 Sparse Gemm kernel + # CUTLASS SYCL - Programming Examples * [00_pvc_gemm](./sycl/00_pvc_gemm) diff --git a/examples/65_distributed_gemm/util/benchmark.h b/examples/common/dist_gemm_helpers.h similarity index 69% rename from examples/65_distributed_gemm/util/benchmark.h rename to examples/common/dist_gemm_helpers.h index 66a0dbb50d..ef258e6922 100644 --- a/examples/65_distributed_gemm/util/benchmark.h +++ b/examples/common/dist_gemm_helpers.h @@ -44,6 +44,11 @@ #include #include +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/cuda_host_adapter.hpp" + namespace cutlass { @@ -115,4 +120,46 @@ struct DistGpuTimer { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Generic device-to-device data movement kernel based for CuTe tensors. +/// +/// NOTE: this kernel assigns one element copy to every thread, and is by no means +/// an efficient way of copying tensors. It should only be used for convenience in +/// reference checks. 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +template +void device_copy(TensorSource tensor_source, + TensorDestination tensor_destination, + cudaStream_t stream); + + +template +__global__ void device_copy_kernel(TensorSource const tensor_source, + TensorDestination tensor_destination) { + auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + using ElementSrc = typename TensorSource::value_type; + using ElementDst = typename TensorDestination::value_type; + NumericConverter converter; + if (linear_idx < size(tensor_source)) { + tensor_destination(linear_idx) = converter(tensor_source(linear_idx)); + } +} + +template +void device_copy(TensorSource tensor_source, + TensorDestination tensor_destination, + cudaStream_t stream) { + + assert(tensor_source.size() == tensor_destination.size()); + + auto numel = tensor_source.size(); + static constexpr int NumThreads = 128; + auto grid_size = cute::ceil_div(numel, NumThreads); + + dim3 grid(grid_size); + dim3 block(NumThreads); + device_copy_kernel<<>>(tensor_source, tensor_destination); +} + } //namespace cutlass diff --git a/examples/cute/tutorial/blackwell/01_mma_sm100.cu b/examples/cute/tutorial/blackwell/01_mma_sm100.cu index 3f73140a01..a11fb17c05 100644 --- a/examples/cute/tutorial/blackwell/01_mma_sm100.cu +++ b/examples/cute/tutorial/blackwell/01_mma_sm100.cu @@ -61,7 +61,8 @@ #include // CuTe tensor implementation #include // CuTe functions for querying the details of cluster launched #include // Compile time in constants such as _1, _256 etc. -#include +#include // Auto vectorized copy operation +#include // TMEM allocator for SM100 // Tutorial helpers #include "example_utils.hpp" @@ -122,7 +123,9 @@ struct SharedStorage alignas(128) cute::ArrayEngine> A; alignas(128) cute::ArrayEngine> B; - alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM + alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM + + alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } @@ -225,6 +228,18 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. 
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + TmemAllocator tmem_allocator{}; + + if (elect_one_warp) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } + __syncthreads(); // Wait for all threads until warp0 allocates TMEM + tCtAcc.data() = shared_storage.tmem_base_ptr; + if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) @@ -233,10 +248,8 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0) } __syncthreads(); - // Barrier Initialization - uint32_t elect_one_thr = cute::elect_one_sync(); - uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + // Barrier Initialization // Barriers in SMEM initialized by a single thread. if (elect_one_warp && elect_one_thr) { cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1); @@ -306,6 +319,15 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) axpby(alpha, tDrAcc, beta, tDrC); // Store RMEM -> GMEM copy(tDrC, tDgD); + + __syncthreads(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + // Then deallocate TMEM + if (elect_one_warp) { + tmem_allocator.release_allocation_lock(); + tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } template // CuTe tensor implementation #include // CuTe functions for querying the details of cluster launched #include // Compile time in constants such as _1, _256 etc. -#include +#include // Auto vectorized copy operation +#include // TMEM allocator for SM100 // Tutorial helpers #include "example_utils.hpp" @@ -124,6 +125,8 @@ struct SharedStorage alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } }; @@ -228,6 +231,18 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. 
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + TmemAllocator tmem_allocator{}; + + if (elect_one_warp) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } + __syncthreads(); // Wait for all threads until warp0 allocates TMEM + tCtAcc.data() = shared_storage.tmem_base_ptr; + if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) @@ -269,9 +284,6 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) } __syncthreads(); // Barrier Initialization - uint32_t elect_one_thr = cute::elect_one_sync(); - uint32_t elect_one_warp = (threadIdx.x / 32 == 0); - // Barriers in SMEM initialized by a single thread. if (elect_one_warp && elect_one_thr) { cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1); @@ -346,6 +358,15 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) axpby(alpha, tDrAcc, beta, tDrC); // Store RMEM -> GMEM copy(tDrC, tDgD); + + __syncthreads(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + // Then deallocate TMEM + if (elect_one_warp) { + tmem_allocator.release_allocation_lock(); + tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } template // CuTe tensor implementation #include // CuTe functions for querying the details of cluster launched #include // Compile time in constants such as _1, _256 etc. -#include +#include // Auto vectorized copy operation +#include // TMEM allocator for SM100 // Tutorial helpers #include "example_utils.hpp" @@ -129,6 +130,8 @@ struct SharedStorage alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } }; @@ -231,6 +234,18 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. 
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + TmemAllocator tmem_allocator{}; + + if (elect_one_warp) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } + __syncthreads(); // Wait for all threads until warp0 allocates TMEM + tCtAcc.data() = shared_storage.tmem_base_ptr; + if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) @@ -305,10 +320,6 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) } __syncthreads(); // Barrier Initialization - - uint32_t elect_one_thr = cute::elect_one_sync(); - uint32_t elect_one_warp = (threadIdx.x / 32 == 0); - // Barriers in SMEM initialized by a single thread. if (elect_one_warp && elect_one_thr) { // The number of CTAs that participates in multicast operation with this CTA (for both A and B matrices) @@ -385,6 +396,15 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) axpby(alpha, tDrAcc, beta, tDrC); // Store RMEM -> GMEM copy(tDrC, tDgD); + + __syncthreads(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + // Then deallocate TMEM + if (elect_one_warp) { + tmem_allocator.release_allocation_lock(); + tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } template // CuTe tensor implementation #include // CuTe functions for querying the details of cluster launched #include // Compile time in constants such as _1, _256 etc. -#include +#include // Auto vectorized copy operation +#include // TMEM allocator for SM100 // Tutorial helpers #include "example_utils.hpp" @@ -132,6 +133,8 @@ struct SharedStorage alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } }; @@ -234,6 +237,18 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. 
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + using TmemAllocator = cute::TMEM::Allocator2Sm; + TmemAllocator tmem_allocator{}; + + if (elect_one_warp) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } + __syncthreads(); // Wait for all threads until warp0 allocates TMEM + tCtAcc.data() = shared_storage.tmem_base_ptr; + if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) @@ -262,6 +277,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // Construct the CTA-in-Cluster coordinate for multicasting auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster())); + auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{}; // Project the cluster_layout for tma_A along the N-modes auto [tAgA, tAsA] = tma_partition(tma_atom_A, @@ -299,10 +315,6 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) } __syncthreads(); // Barrier Initialization - auto elect_one_thr = cute::elect_one_sync(); - auto elect_one_warp = (threadIdx.x / 32 == 0); - auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{}; - // Barriers in SMEM should be initialized by a single thread. if (elect_one_warp && elect_one_thr) { // The number of CTAs that participates in multicast operation with this CTA (for both A and B matrices) @@ -386,6 +398,15 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) axpby(alpha, tDrAcc, beta, tDrC); // Store RMEM -> GMEM copy(tDrC, tDgD); + + __syncthreads(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + // Then deallocate TMEM + if (elect_one_warp) { + tmem_allocator.release_allocation_lock(); + tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } template // CuTe tensor implementation #include // CuTe functions for querying the details of cluster launched #include // Compile time in constants such as _1, _256 etc. -#include +#include // Auto vectorized copy operation +#include // TMEM allocator for SM100 // Tutorial helpers #include "example_utils.hpp" @@ -140,6 +141,8 @@ struct SharedStorage alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(tensors.mainloop.A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(tensors.mainloop.B.begin()), BSmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sC() { return make_tensor(make_smem_ptr(tensors.C.begin()), CSmemLayout{}); } @@ -247,6 +250,18 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. 
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + using TmemAllocator = cute::TMEM::Allocator2Sm; + TmemAllocator tmem_allocator{}; + + if (elect_one_warp) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } + __syncthreads(); // Wait for all threads until warp0 allocates TMEM + tCtAcc.data() = shared_storage.tmem_base_ptr; + if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) @@ -275,6 +290,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // Construct the CTA-in-Cluster coordinate for multicasting auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster())); + auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{}; // Project the cluster_layout for tma_A along the N-modes auto [tAgA, tAsA] = tma_partition(tma_atom_A, @@ -312,10 +328,6 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) } __syncthreads(); // Barrier Initialization - auto elect_one_thr = cute::elect_one_sync(); - auto elect_one_warp = (threadIdx.x / 32 == 0); - auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{}; - // Barriers in SMEM should be initialized by a single thread. if (elect_one_warp && elect_one_thr) { // The number of CTAs that participates in multicast operation with this CTA (for both A and B matrices) @@ -441,6 +453,14 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) } __syncthreads(); // All threads sync with issuing thread } + __syncthreads(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + // Then deallocate TMEM + if (elect_one_warp) { + tmem_allocator.release_allocation_lock(); + tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } template ::RasterOrderOptions; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; // Per-GEMM problem shape info may only exist on the device. 
if (host_problem_shapes_available) { diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp index b1490b02db..f0a993593d 100644 --- a/include/cute/algorithm/cooperative_gemm.hpp +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -98,19 +98,23 @@ epilogue_predication(ThrMMA const& thr_mma, } } -template + class SmemCopyLdOpC, class SmemCopyStOpC> CUTE_HOST_DEVICE void -epilogue_no_predication(Alpha const& alpha, +epilogue_no_predication(uint32_t thread_idx, + ThrMMA const& thr_mma, + Alpha const& alpha, Tensor & tCrC, Beta const& beta, - Tensor & tCsC, + Tensor & sC, CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C - SmemCopyOpC const& sC_copy_op) + SmemCopyLdOpC const& sC_copy_ld_op, + SmemCopyStOpC const& sC_copy_st_op) { using InputTypeC = typename TSC::value_type; using ComputeTypeC = typename TRC::value_type; @@ -125,10 +129,18 @@ epilogue_no_predication(Alpha const& alpha, CUTE_GCC_UNREACHABLE; } (); - Tensor tCrDi = make_fragment_like(tCsC); Tensor tCrD = make_fragment_like(tCrC); + Tensor tCrDi = make_fragment_like(tCrD); + if(!isBetaZero) { - copy(sC_copy_op, tCsC, tCrDi); + auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom{}, thr_mma); + auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx); + Tensor tCsC = smem_thr_copy_C.partition_S(sC); + Tensor tCrDi_copy_view = smem_thr_copy_C.retile_D(tCrDi); + CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N + copy(smem_tiled_copy_C, tCsC, tCrDi_copy_view); + // Transform C on/after load cute::transform(tCrDi, tCrD, sC_load_op); } @@ -136,7 +148,14 @@ epilogue_no_predication(Alpha const& alpha, axpby(alpha, tCrC, beta, tCrD); // Transform C before/on store cute::transform(tCrD, tCrDi, sC_store_op); - copy(sC_copy_op, tCrDi, tCsC); + + auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom{}, thr_mma); + auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx); + Tensor tCsC = smem_thr_copy_C.partition_D(sC); + Tensor tCrDi_copy_view = smem_thr_copy_C.retile_S(tCrDi); + CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N + copy(smem_tiled_copy_C, tCrDi_copy_view, tCsC); } // Predicated Cooperative GEMM @@ -283,7 +302,9 @@ cooperative_gemm_no_predication(uint32_t thread_idx, // Create register tensors for the MMA to operate on Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCrAi = make_fragment_like(tCrA); Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCrBi = make_fragment_like(tCrB); using CopyOpAType = SmemCopyOpA; using CopyOpBType = SmemCopyOpB; @@ -291,7 +312,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx, auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, thr_mma); auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); Tensor tCsA = smem_thr_copy_A.partition_S(sA); - Tensor tCrAi = make_fragment_like(tCsA); Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi); CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K @@ -299,7 +319,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx, auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); auto 
smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); Tensor tCsB = smem_thr_copy_B.partition_S(sB); - Tensor tCrBi = make_fragment_like(tCsB); Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K @@ -346,7 +365,7 @@ template + class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy> CUTE_HOST_DEVICE void cooperative_gemm(uint32_t thread_idx, @@ -356,13 +375,14 @@ cooperative_gemm(uint32_t thread_idx, Tensor const& sB, Beta const& beta, Tensor & sC, - ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C - SmemCopyOpA const& sA_copy_op = {}, - SmemCopyOpB const& sB_copy_op = {}, - SmemCopyOpC const& sC_copy_op = {}) + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyLdOpC const& sC_copy_ld_op = {}, + SmemCopyStOpC const& sC_copy_st_op = {}) { CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); @@ -394,7 +414,7 @@ cooperative_gemm(uint32_t thread_idx, thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op ); detail::epilogue_no_predication( - alpha, tCrC, beta, tCsC, sC_load_op, sC_store_op, sC_copy_op + thread_idx, thr_mma,alpha, tCrC, beta, sC, sC_load_op, sC_store_op, sC_copy_ld_op, sC_copy_st_op ); } else { detail::cooperative_gemm_predication( @@ -466,7 +486,7 @@ template + class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy> CUTE_HOST_DEVICE void cooperative_gemm(uint32_t thread_idx, @@ -476,17 +496,18 @@ cooperative_gemm(uint32_t thread_idx, Tensor const& sB, Beta const& beta, Tensor && sC, - ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C - SmemCopyOpA const& sA_copy_op = {}, - SmemCopyOpB const& sB_copy_op = {}, - SmemCopyOpC const& sC_copy_op = {}) + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyLdOpC const& sC_copy_ld_op = {}, + SmemCopyStOpC const& sC_copy_st_op = {}) { cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op, - sA_copy_op, sB_copy_op, sC_copy_op); + sA_copy_op, sB_copy_op, sC_copy_ld_op, sC_copy_st_op); } // 
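The cooperative_gemm changes above split the single SmemCopyOpC parameter into SmemCopyLdOpC and SmemCopyStOpC, so the epilogue can load and store C through shared memory with different copy atoms. A hedged sketch of a call site under the updated signature (the wrapper function and the DefaultCopy/identity choices are illustrative, not taken from this patch):

```cpp
// Illustrative call under the updated cooperative_gemm signature, which now
// accepts separate copy ops for loading and storing C through shared memory.
#include <cute/tensor.hpp>
#include <cute/algorithm/cooperative_gemm.hpp>

template <class TiledMma,
          class EngineA, class LayoutA,
          class EngineB, class LayoutB,
          class EngineC, class LayoutC>
CUTE_DEVICE void
gemm_with_split_c_copy(uint32_t thread_idx, TiledMma const& tiled_mma,
                       cute::Tensor<EngineA, LayoutA> const& sA,   // (M,K) in SMEM
                       cute::Tensor<EngineB, LayoutB> const& sB,   // (N,K) in SMEM
                       cute::Tensor<EngineC, LayoutC>      & sC)   // (M,N) in SMEM
{
  using namespace cute;
  float alpha = 1.0f;
  float beta  = 1.0f;
  cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
                   identity{}, identity{},   // sA_load_op,  sB_load_op
                   identity{}, identity{},   // sC_load_op,  sC_store_op
                   DefaultCopy{},            // sA_copy_op
                   DefaultCopy{},            // sB_copy_op
                   DefaultCopy{},            // sC_copy_ld_op (new)
                   DefaultCopy{});           // sC_copy_st_op (new)
}
```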
Legacy overload of cute::gemm for backwards-compatibility diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp index 5055605315..cec86c4d6d 100644 --- a/include/cute/algorithm/tuple_algorithms.hpp +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -33,6 +33,7 @@ #include #include +#include #include #include #include @@ -283,34 +284,13 @@ transform_leaf(T0 const& t0, T1 const& t1, F&& f) // find and find_if // -namespace detail { - -template -CUTE_HOST_DEVICE constexpr -auto -find_if(T const& t, F&& f, seq) -{ - if constexpr (decltype(f(get(t)))::value) { - return cute::C{}; - } else - if constexpr (sizeof...(Is) == 0) { - return cute::C{}; - } else { - return find_if(t, f, seq{}); - } - - CUTE_GCC_UNREACHABLE; -} - -} // end namespace detail - template CUTE_HOST_DEVICE constexpr auto find_if(T const& t, F&& f) { if constexpr (is_tuple::value) { - return detail::find_if(t, f, tuple_seq{}); + return detail::tapply(t, f, [] (auto... a) { return cute::C>{}; }, tuple_seq{}); } else { return cute::C{}; } @@ -332,7 +312,7 @@ auto any_of(T const& t, F&& f) { if constexpr (is_tuple::value) { - return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq{}); + return detail::tapply(t, f, [] (auto... a) { return (false_type{} || ... || a); }, tuple_seq{}); } else { return f(t); } @@ -346,7 +326,7 @@ auto all_of(T const& t, F&& f) { if constexpr (is_tuple::value) { - return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (true_type{} && ... && a); }, tuple_seq{}); + return detail::tapply(t, f, [] (auto... a) { return (true_type{} && ... && a); }, tuple_seq{}); } else { return f(t); } diff --git a/include/cute/arch/cluster_sm90.hpp b/include/cute/arch/cluster_sm90.hpp index ba22ef1ca5..524a47efb1 100644 --- a/include/cute/arch/cluster_sm90.hpp +++ b/include/cute/arch/cluster_sm90.hpp @@ -31,6 +31,7 @@ #pragma once #include +#include // Config #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp index 9158953886..2383b4e6c6 100644 --- a/include/cute/arch/config.hpp +++ b/include/cute/arch/config.hpp @@ -72,6 +72,27 @@ # define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED #endif +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) # define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED #endif @@ -91,8 +112,11 @@ #endif // {add, mul, fma}.f32x2 PTX -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)) - #define CUTE_ARCH_FLOAT2_MATH_ENABLED +#if defined(CUTLASS_ARCH_MMA_SM100_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) + // Enable CuTe MMA Atoms +# define CUTE_ARCH_FFMA2_SM100_ENABLED + // Enable f32x2 PTX generation +# 
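The tuple_algorithms.hpp hunk above reroutes find_if, any_of, and all_of through detail::tapply without changing what they return. A small compile-time sketch of those preserved semantics (the tuple and predicate are made up for illustration):

```cpp
// Compile-time behavior preserved by the find_if/any_of/all_of refactor above.
#include <cute/container/tuple.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/algorithm/tuple_algorithms.hpp>

inline void tuple_algorithms_demo()
{
  using namespace cute;
  auto gt4 = [](auto v) { return C<(decltype(v)::value > 4)>{}; };  // static predicate
  auto t   = cute::make_tuple(Int<2>{}, Int<8>{}, Int<3>{});

  auto idx = find_if(t, gt4);   // cute::C<1>: Int<8> is the first element > 4
  auto any = any_of(t, gt4);    // statically true
  auto all = all_of(t, gt4);    // statically false

  static_assert(decltype(idx)::value == 1);
  static_assert(decltype(any)::value == true);
  static_assert(decltype(all)::value == false);
  (void)idx; (void)any; (void)all;
}
```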
define CUTE_ARCH_FLOAT2_MATH_ENABLED #endif #if defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) @@ -109,3 +133,37 @@ # endif #endif +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +# define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) +# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) +# define CUTE_ARCH_LOAD256_SM100A_ENABLED +# define CUTE_ARCH_STORE256_SM100A_ENABLED +# endif +#endif + +// {add, mul, fma}.f32x2 PTX +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) + #define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + diff --git a/include/cute/arch/copy_sm100.hpp b/include/cute/arch/copy_sm100.hpp index 19b13841a1..aa969afe9b 100644 --- a/include/cute/arch/copy_sm100.hpp +++ b/include/cute/arch/copy_sm100.hpp @@ -28,10 +28,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - -// - -// #pragma once #include @@ -316,17 +312,14 @@ struct SM100_U8x16_STSM_T } }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cute - //////////////////////////////////////////////////////////////////////////////////////////////////// // // UTCCP PTX definitions // //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace cute { +namespace SM100::TMEM::UTCCP { + // 128 data path lanes, 256-bit pattern, 1cta mode struct SM100_UTCCP_128dp256bit_1cta { @@ -558,21 +551,19 @@ struct SM100_UTCCP_2x64dp128bitlw0123_2cta } }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cute +} // end namespace SM100::TMEM::UTCCP //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace cute { +namespace SM100::TMEM::LOAD { //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// // -// TMEM_LOAD PTX definitions +// TMEM LOAD PTX definitions // //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -3945,7 +3936,6 @@ struct SM100_TMEM_LOAD_32dp32b128x } }; - //////////////////////////////////////////////////////////////////////////////////////////////////// // 32 data path lanes, 32-bit pattern, repeated 128 times, packed 16b read @@ -4065,9 +4055,21 @@ struct SM100_TMEM_LOAD_32dp32b128x_16b 
//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM100::TMEM::LOAD + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM100::TMEM::STORE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////////// // -// TMEM_STORE PTX definitions +// TMEM STORE PTX definitions // //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -4086,8 +4088,8 @@ struct SM100_TMEM_STORE_16dp256b1x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4110,8 +4112,8 @@ struct SM100_TMEM_STORE_16dp256b1x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.unpack::16b.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4136,8 +4138,8 @@ struct SM100_TMEM_STORE_16dp256b2x asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -4163,8 +4165,8 @@ struct SM100_TMEM_STORE_16dp256b2x_16b asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.unpack::16b.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -4194,8 +4196,8 @@ struct SM100_TMEM_STORE_16dp256b4x "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4227,8 +4229,8 @@ struct SM100_TMEM_STORE_16dp256b4x_16b "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4268,8 +4270,8 @@ struct SM100_TMEM_STORE_16dp256b8x "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4313,8 +4315,8 @@ struct SM100_TMEM_STORE_16dp256b8x_16b "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), 
"r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4374,8 +4376,8 @@ struct SM100_TMEM_STORE_16dp256b16x "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4443,8 +4445,8 @@ struct SM100_TMEM_STORE_16dp256b16x_16b "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4544,8 +4546,8 @@ struct SM100_TMEM_STORE_16dp256b32x "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -4661,8 +4663,8 @@ struct SM100_TMEM_STORE_16dp256b32x_16b "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -4716,8 +4718,8 @@ struct SM100_TMEM_STORE_16dp128b1x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4740,8 +4742,8 @@ struct SM100_TMEM_STORE_16dp128b1x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.unpack::16b.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4764,8 +4766,8 @@ struct SM100_TMEM_STORE_16dp128b2x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4788,8 +4790,8 @@ struct SM100_TMEM_STORE_16dp128b2x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.unpack::16b.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4814,8 +4816,8 @@ struct SM100_TMEM_STORE_16dp128b4x asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -4841,8 +4843,8 @@ struct SM100_TMEM_STORE_16dp128b4x_16b asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.unpack::16b.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, 
%8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -4872,8 +4874,8 @@ struct SM100_TMEM_STORE_16dp128b8x "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4905,8 +4907,8 @@ struct SM100_TMEM_STORE_16dp128b8x_16b "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4946,8 +4948,8 @@ struct SM100_TMEM_STORE_16dp128b16x "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4991,8 +4993,8 @@ struct SM100_TMEM_STORE_16dp128b16x_16b "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5052,8 +5054,8 @@ struct SM100_TMEM_STORE_16dp128b32x "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5121,8 +5123,8 @@ struct SM100_TMEM_STORE_16dp128b32x_16b "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5222,8 +5224,8 @@ struct SM100_TMEM_STORE_16dp128b64x "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -5339,8 +5341,8 @@ struct SM100_TMEM_STORE_16dp128b64x_16b "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -5394,8 +5396,8 @@ struct SM100_TMEM_STORE_16dp64b1x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.b32" "[%0]," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5418,8 +5420,8 @@ struct SM100_TMEM_STORE_16dp64b1x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.unpack::16b.b32" "[%0]," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else 
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5442,8 +5444,8 @@ struct SM100_TMEM_STORE_16dp64b2x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5466,8 +5468,8 @@ struct SM100_TMEM_STORE_16dp64b2x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.unpack::16b.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5490,8 +5492,8 @@ struct SM100_TMEM_STORE_16dp64b4x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5514,8 +5516,8 @@ struct SM100_TMEM_STORE_16dp64b4x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.unpack::16b.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5540,8 +5542,8 @@ struct SM100_TMEM_STORE_16dp64b8x asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -5567,8 +5569,8 @@ struct SM100_TMEM_STORE_16dp64b8x_16b asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.unpack::16b.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -5598,8 +5600,8 @@ struct SM100_TMEM_STORE_16dp64b16x "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5631,8 +5633,8 @@ struct SM100_TMEM_STORE_16dp64b16x_16b "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5672,8 +5674,8 @@ struct SM100_TMEM_STORE_16dp64b32x "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5717,8 +5719,8 @@ struct SM100_TMEM_STORE_16dp64b32x_16b "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), 
"r"(src10), "r"(src11), @@ -5778,8 +5780,8 @@ struct SM100_TMEM_STORE_16dp64b64x "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5847,8 +5849,8 @@ struct SM100_TMEM_STORE_16dp64b64x_16b "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5948,8 +5950,8 @@ struct SM100_TMEM_STORE_16dp64b128x "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -6065,8 +6067,8 @@ struct SM100_TMEM_STORE_16dp64b128x_16b "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -6120,8 +6122,8 @@ struct SM100_TMEM_STORE_16dp32b1x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.b32" "[%0] , 1," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6144,8 +6146,8 @@ struct SM100_TMEM_STORE_16dp32b1x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.unpack::16b.b32" "[%0] , 2," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6168,8 +6170,8 @@ struct SM100_TMEM_STORE_16dp32b2x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.b32" "[%0] , 2," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6192,8 +6194,8 @@ struct SM100_TMEM_STORE_16dp32b2x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.unpack::16b.b32" "[%0] , 4," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6216,8 +6218,8 @@ struct SM100_TMEM_STORE_16dp32b4x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.b32" "[%0] , 4," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6240,8 +6242,8 @@ struct SM100_TMEM_STORE_16dp32b4x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.unpack::16b.b32" "[%0] , 8," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), 
"r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6266,8 +6268,8 @@ struct SM100_TMEM_STORE_16dp32b8x asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.b32" "[%0] , 8," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -6293,8 +6295,8 @@ struct SM100_TMEM_STORE_16dp32b8x_16b asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.unpack::16b.b32" "[%0] , 16," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -6324,8 +6326,8 @@ struct SM100_TMEM_STORE_16dp32b16x "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6357,8 +6359,8 @@ struct SM100_TMEM_STORE_16dp32b16x_16b "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6398,8 +6400,8 @@ struct SM100_TMEM_STORE_16dp32b32x "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6443,8 +6445,8 @@ struct SM100_TMEM_STORE_16dp32b32x_16b "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6504,8 +6506,8 @@ struct SM100_TMEM_STORE_16dp32b64x "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6573,8 +6575,8 @@ struct SM100_TMEM_STORE_16dp32b64x_16b "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6674,8 +6676,8 @@ struct SM100_TMEM_STORE_16dp32b128x "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -6791,8 +6793,8 @@ struct SM100_TMEM_STORE_16dp32b128x_16b "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), 
"r"(src009), "r"(src010), "r"(src011), @@ -6846,8 +6848,8 @@ struct SM100_TMEM_STORE_32dp32b1x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.b32" "[%0]," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6870,8 +6872,8 @@ struct SM100_TMEM_STORE_32dp32b1x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.unpack::16b.b32" "[%0]," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6894,8 +6896,8 @@ struct SM100_TMEM_STORE_32dp32b2x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6918,8 +6920,8 @@ struct SM100_TMEM_STORE_32dp32b2x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.unpack::16b.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6942,8 +6944,8 @@ struct SM100_TMEM_STORE_32dp32b4x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6966,8 +6968,8 @@ struct SM100_TMEM_STORE_32dp32b4x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.unpack::16b.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6992,8 +6994,8 @@ struct SM100_TMEM_STORE_32dp32b8x asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -7019,8 +7021,8 @@ struct SM100_TMEM_STORE_32dp32b8x_16b asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.unpack::16b.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -7050,8 +7052,8 @@ struct SM100_TMEM_STORE_32dp32b16x "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7083,8 +7085,8 @@ struct SM100_TMEM_STORE_32dp32b16x_16b "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7124,8 +7126,8 @@ struct SM100_TMEM_STORE_32dp32b32x 
"%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7169,8 +7171,8 @@ struct SM100_TMEM_STORE_32dp32b32x_16b "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7230,8 +7232,8 @@ struct SM100_TMEM_STORE_32dp32b64x "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7299,8 +7301,8 @@ struct SM100_TMEM_STORE_32dp32b64x_16b "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7400,8 +7402,8 @@ struct SM100_TMEM_STORE_32dp32b128x "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -7517,8 +7519,8 @@ struct SM100_TMEM_STORE_32dp32b128x_16b "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -7561,7 +7563,8 @@ struct SM100_TMEM_STORE_32dp32b128x_16b //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cute +} // namespace SM100::TMEM::STORE //////////////////////////////////////////////////////////////////////////////////////////////////// +} // end namespace cute diff --git a/include/cute/arch/mma_sm100.hpp b/include/cute/arch/mma_sm100.hpp index 2fa532d2ef..749da8167e 100644 --- a/include/cute/arch/mma_sm100.hpp +++ b/include/cute/arch/mma_sm100.hpp @@ -29,7 +29,6 @@ * **************************************************************************************************/ // - // #pragma once @@ -37,6 +36,48 @@ #include #include +#include + namespace cute { +struct SM100_2x1x1_F32F32F32F32 { + using DRegisters = float2[1]; + using ARegisters = float2[1]; + using BRegisters = float[1]; + using CRegisters = float2[1]; + + CUTE_HOST_DEVICE static void + fma(float2 & d01, + float2 const& a01, + float const& b0, + float2 const& c01) + { +#if defined(CUTE_ARCH_FFMA2_SM100_ENABLED) + cute::fma(d01, a01, make_float2(b0, b0), c01); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_2x1x1_F32F32F32F32 without CUTE_ARCH_FLOAT2_MATH_ENABLED"); +#endif + } +}; + +struct SM100_1x2x1_F32F32F32F32 { + using DRegisters = float2[1]; + using ARegisters = float[1]; + using BRegisters = float2[1]; + using CRegisters = float2[1]; + + CUTE_HOST_DEVICE static void + 
fma(float2 & d01, + float const& a0, + float2 const& b01, + float2 const& c01) + { +#if defined(CUTE_ARCH_FFMA2_SM100_ENABLED) + cute::fma(d01, make_float2(a0, a0), b01, c01); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_1x2x1_F32F32F32F32 without CUTE_ARCH_FFMA2_SM100_ENABLED"); +#endif + } +}; + } // namespace cute diff --git a/include/cute/arch/mma_sm120.hpp b/include/cute/arch/mma_sm120.hpp index 84c09b8b93..1433a2c8d0 100644 --- a/include/cute/arch/mma_sm120.hpp +++ b/include/cute/arch/mma_sm120.hpp @@ -3245,7 +3245,7 @@ rr_blockscaled_op_selector_sm120() { if constexpr (UseF8F6F4) { return SM120::BLOCKSCALED::SM120_16x8x32_TN_VS{}; - } + } else{ return SM120::BLOCKSCALED::SM120_16x8x64_TN_VS{}; } diff --git a/include/cute/arch/mma_sm89.hpp b/include/cute/arch/mma_sm89.hpp new file mode 100644 index 0000000000..85d7bb64ae --- /dev/null +++ b/include/cute/arch/mma_sm89.hpp @@ -0,0 +1,180 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +// + +// +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) +# define CUTE_ARCH_MMA_F32_SM89_SUPPORTED +#endif + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) +# define CUTE_ARCH_MMA_F16_SM89_SUPPORTED +#endif + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) +# if defined(CUTE_ARCH_MMA_F32_SM89_SUPPORTED) +# define CUTE_ARCH_MMA_F32_SM89_ENABLED +# endif + +# if defined(CUTE_ARCH_MMA_F16_SM89_SUPPORTED) +# define CUTE_ARCH_MMA_F16_SM89_ENABLED +# endif +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cute { +// MMA 16x8x32 TN +struct SM89_16x8x32_F32E4M3E4M3F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E4M3E4M3F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F32E4M3E5M2F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E4M3E5M2F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F32E5M2E5M2F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E5M2E5M2F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F32E5M2E4M3F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E5M2E4M3F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +} // namespace cute diff --git a/include/cute/arch/simd_sm100.hpp b/include/cute/arch/simd_sm100.hpp index 1c07a31e6d..58d8810e47 100644 --- a/include/cute/arch/simd_sm100.hpp +++ b/include/cute/arch/simd_sm100.hpp @@ -37,7 +37,6 @@ #include #include #include - namespace cute { CUTE_HOST_DEVICE diff --git a/include/cute/arch/tmem_allocator_sm100.hpp b/include/cute/arch/tmem_allocator_sm100.hpp index 6cd9223b76..347a619508 100644 --- a/include/cute/arch/tmem_allocator_sm100.hpp +++ b/include/cute/arch/tmem_allocator_sm100.hpp @@ -28,19 +28,34 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -// -// #pragma once #include -#include -#include - -#include +#include +#include +#include namespace cute::TMEM { +// +// TMEM Addressing Constants +// + +// 128 DP x 512 COL x uint32_t-addressing +using MAX_CAPACITY_BITS = Int<128*512*32>; + +// TMEM DP stride in bit-addressing (shift by 5 for conversion from uint32_t) +using DP_b = cute::constant; + +// TMEM DP stride in type-T addressing +template +using DP = cute::constant::OffsetShift)>; + +// +// TMEM Allocators +// + // All operations of this class require that only a single warp uniformly participates class Allocator1Sm { public: @@ -57,7 +72,7 @@ class Allocator1Sm { * @pre Must never be issued by more than one warp at the same time. * @pre For repeated allocations, the same warp must be used to issue all allocations. **/ - CUTLASS_DEVICE void + CUTE_HOST_DEVICE void allocate(int num_columns, uint32_t* dst_ptr) { #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr); @@ -77,8 +92,8 @@ class Allocator1Sm { asm volatile( "{\n\t" "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t" - "}" - : + "}" + : : "r"(tmem_ptr), "r"(num_columns)); #else CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); @@ -116,7 +131,7 @@ class Allocator2Sm { * @pre For repeated allocations, the same warp must be used to issue all allocations. * @pre The 2 warps from participating CTAs have the same logical warp ID. 
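The new mma_sm89.hpp above exposes the Ada FP8 mma.sync.m16n8k32 instructions as plain CuTe arch structs. A minimal device-side sketch that drives one of them directly through its fma() interface (the register contents are placeholders; a real kernel would obtain packed e4m3 fragments from a CuTe partitioning and would normally go through an MMA atom rather than calling the op by hand):

```cpp
// Direct use of one of the SM89 FP8 MMA ops added above (requires sm_89 and,
// per the guards in this patch, CUDA 12.4+ for the f32-accumulator variants).
#include <cute/arch/mma_sm89.hpp>

__global__ void sm89_fp8_mma_demo(float* out)
{
  uint32_t a[4] = {0u, 0u, 0u, 0u};   // 16 e4m3 values per thread (4 per uint32_t)
  uint32_t b[2] = {0u, 0u};           //  8 e4m3 values per thread
  float    c[4] = {0.f, 0.f, 0.f, 0.f};
  float    d[4];

  cute::SM89_16x8x32_F32E4M3E4M3F32_TN::fma(
      d[0], d[1], d[2], d[3],
      a[0], a[1], a[2], a[3],
      b[0], b[1],
      c[0], c[1], c[2], c[3]);

  for (int i = 0; i < 4; ++i) {
    out[threadIdx.x * 4 + i] = d[i];
  }
}
```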
**/ - CUTLASS_DEVICE void + CUTE_HOST_DEVICE void allocate(int num_columns, uint32_t* dst_ptr) { #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr); @@ -130,7 +145,7 @@ class Allocator2Sm { } /** - * Frees the TMEM corresponding to the pointer and slice count provided. + * Frees the TMEM corresponding to the pointer and slice count provided. * Release the TMEM after checking that the CTA issuing the free does indeed own the corresponding slices. * @param tmem_ptr Base address of the TMEM address space being freed. * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. @@ -146,8 +161,8 @@ class Allocator2Sm { asm volatile( "{\n\t" "tcgen05.dealloc.cta_group::2.sync.aligned.b32 %0, %1; \n\t" - "}" - : + "}" + : : "r"(tmem_ptr), "r"(num_columns)); #else CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index b0899f7a83..c9ff9ef878 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -88,7 +88,7 @@ namespace cute { /// CUTE helper to cast SMEM pointer to unsigned -CUTE_DEVICE +CUTE_HOST_DEVICE uint32_t cast_smem_ptr_to_uint(void const* const ptr) { diff --git a/include/cute/atom/copy_traits_sm100.hpp b/include/cute/atom/copy_traits_sm100.hpp index 6a767ae3c0..594149d4fd 100644 --- a/include/cute/atom/copy_traits_sm100.hpp +++ b/include/cute/atom/copy_traits_sm100.hpp @@ -28,13 +28,11 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -// - -// #pragma once #include +#include #include #include @@ -230,92 +228,11 @@ struct Copy_Traits using RefLayout = SrcLayout; }; -namespace TMEM { - using MAX_CAPACITY_BITS = Int<128*512*32>; // 128 DP x 512 COL x uint32_t-addressing - - template // TMEM DP stride in type-T addressing - using DP = cute::constant::OffsetShift)>; - - using DP_b = cute::constant; // TMEM DP stride in bit-addressing (shift by 5 for conversion from uint32_t) -} - -// TMEM_LOAD copy_unpack -template -struct TMEM_LOAD_Unpack -{ - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_tmem::value, "Expected TMEM src."); - static_assert(is_rmem::value, "Expected RMEM dst."); - - using SrcType = typename TS::value_type; - CUTE_STATIC_ASSERT_V((coalesce(layout(src)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), - "Expected src to have the specific TMEM layout required by CopyOp."); - - uint32_t tmem_addr = raw_pointer_cast(src.data()); - - using RegTypeDst = typename remove_extent::type; - Tensor rD = recast(dst); - - constexpr int RegNumDst = extent::value; - CUTE_STATIC_ASSERT_V(size(rD) == Int{}, - "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this CopyOp."); - - // thread idx <=> DP lane assert. - // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. 
-#if defined(__CUDA_ARCH__) && !defined(NDEBUG) - assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); -#endif - - detail::explode(CopyOp::copy, - &tmem_addr, seq<0>{}, - rD, make_seq{}); - } -}; - -// TMEM_STORE copy_unpack -template -struct TMEM_STORE_Unpack -{ - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected RMEM src."); - static_assert(is_tmem::value, "Expected TMEM dst."); - - using RegTypeSrc = typename remove_extent::type; - Tensor rS = recast(src); - - constexpr int RegNumSrc = extent::value; - CUTE_STATIC_ASSERT_V(size(rS) == Int{}, - "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); - - using DstType = typename TD::value_type; - CUTE_STATIC_ASSERT_V((coalesce(layout(dst)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), - "Expected dst to have the specific TMEM layout required by CopyOp."); - - uint32_t tmem_addr = raw_pointer_cast(dst.data()); - - // thread idx <=> DP lane assert. - // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. -#if defined(__CUDA_ARCH__) && !defined(NDEBUG) - assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); -#endif - - detail::explode(CopyOp::copy, - rS, make_seq{}, - &tmem_addr, seq<0>{}); - } -}; +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM Traits and Utilities +// +//////////////////////////////////////////////////////////////////////////////////////////////////// template struct Copy_Atom; @@ -418,817 +335,162 @@ make_tmem_warp_partitioner(Tensor const& tmem) return make_tiler_impl(layout_tv, tiler); } -} // end namespace cute +namespace SM100::TMEM::LOAD { + +// +// Specialized copy_unpack implementation for SM100::TMEM::LOAD instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_tmem::value, "Expected TMEM src."); + static_assert(is_rmem::value, "Expected RMEM dst."); + + using SrcType = typename TS::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(src)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected src to have the specific TMEM layout required by CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(src.data()); + + using RegTypeDst = typename remove_extent::type; + Tensor rD = recast(dst); + + constexpr int RegNumDst = extent::value; + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this CopyOp."); + + // thread idx <=> DP lane assert. + // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. 
+#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + + detail::explode(CopyOp::copy, + &tmem_addr, seq<0>{}, + rD, make_seq{}); +} + +} // end namespace SM100::TMEM::LOAD + +namespace SM100::TMEM::STORE { + +// +// Specialized copy_unpack implementation for SM100::TMEM::STORE instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_rmem::value, "Expected RMEM src."); + static_assert(is_tmem::value, "Expected TMEM dst."); + + using RegTypeSrc = typename remove_extent::type; + Tensor rS = recast(src); + + constexpr int RegNumSrc = extent::value; + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + + using DstType = typename TD::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(dst)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected dst to have the specific TMEM layout required by CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(dst.data()); + + // thread idx <=> DP lane assert. + // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. +#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + + detail::explode(CopyOp::copy, + rS, make_seq{}, + &tmem_addr, seq<0>{}); +} + +} // end namespace SM100::TMEM::STORE + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM_LOAD Copy Traits +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace cute { +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b1x; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + // Logical bit id to bit idx (address) + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_2048>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_2048>>>; + using RefLayout = SrcLayout; +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace TMEM { +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _2>>, + Stride,Stride< _1,_4096,_256>>>; + using RefLayout = SrcLayout; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// -// Given a 1x tmem copy op, returns the widest repeated variant that divides the specified bits in the 
N-mode -template -CUTE_HOST_DEVICE constexpr -auto -op_repeater() -{ - if constexpr (cute::is_same_v) { - if constexpr (bits_n % (256 * 32) == 0) { - return SM100_TMEM_LOAD_16dp256b32x{}; - } - else if constexpr (bits_n % (256 * 16) == 0) { - return SM100_TMEM_LOAD_16dp256b16x{}; - } - else if constexpr (bits_n % (256 * 8) == 0) { - return SM100_TMEM_LOAD_16dp256b8x{}; - } - else if constexpr (bits_n % (256 * 4) == 0) { - return SM100_TMEM_LOAD_16dp256b4x{}; - } - else if constexpr (bits_n % (256 * 2) == 0) { - return SM100_TMEM_LOAD_16dp256b2x{}; - } - else if constexpr (bits_n % (256 * 1) == 0) { - return SM100_TMEM_LOAD_16dp256b1x{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (256 * 32) == 0) { - return SM100_TMEM_LOAD_16dp256b32x_16b{}; - } - else if constexpr (bits_n % (256 * 16) == 0) { - return SM100_TMEM_LOAD_16dp256b16x_16b{}; - } - else if constexpr (bits_n % (256 * 8) == 0) { - return SM100_TMEM_LOAD_16dp256b8x_16b{}; - } - else if constexpr (bits_n % (256 * 4) == 0) { - return SM100_TMEM_LOAD_16dp256b4x_16b{}; - } - else if constexpr (bits_n % (256 * 2) == 0) { - return SM100_TMEM_LOAD_16dp256b2x_16b{}; - } - else if constexpr (bits_n % (256 * 1) == 0) { - return SM100_TMEM_LOAD_16dp256b1x_16b{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (128 * 64) == 0) { - return SM100_TMEM_LOAD_16dp128b64x{}; - } - else if constexpr (bits_n % (128 * 32) == 0) { - return SM100_TMEM_LOAD_16dp128b32x{}; - } - else if constexpr (bits_n % (128 * 16) == 0) { - return SM100_TMEM_LOAD_16dp128b16x{}; - } - else if constexpr (bits_n % (128 * 8) == 0) { - return SM100_TMEM_LOAD_16dp128b8x{}; - } - else if constexpr (bits_n % (128 * 4) == 0) { - return SM100_TMEM_LOAD_16dp128b4x{}; - } - else if constexpr (bits_n % (128 * 2) == 0) { - return SM100_TMEM_LOAD_16dp128b2x{}; - } - else if constexpr (bits_n % (128 * 1) == 0) { - return SM100_TMEM_LOAD_16dp128b1x{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (128 * 64) == 0) { - return SM100_TMEM_LOAD_16dp128b64x_16b{}; - } - else if constexpr (bits_n % (128 * 32) == 0) { - return SM100_TMEM_LOAD_16dp128b32x_16b{}; - } - else if constexpr (bits_n % (128 * 16) == 0) { - return SM100_TMEM_LOAD_16dp128b16x_16b{}; - } - else if constexpr (bits_n % (128 * 8) == 0) { - return SM100_TMEM_LOAD_16dp128b8x_16b{}; - } - else if constexpr (bits_n % (128 * 4) == 0) { - return SM100_TMEM_LOAD_16dp128b4x_16b{}; - } - else if constexpr (bits_n % (128 * 2) == 0) { - return SM100_TMEM_LOAD_16dp128b2x_16b{}; - } - else if constexpr (bits_n % (128 * 1) == 0) { - return SM100_TMEM_LOAD_16dp128b1x_16b{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (64 * 128) == 0) { - return SM100_TMEM_LOAD_16dp64b128x{}; - } - else if constexpr (bits_n % (64 * 64) == 0) { - return SM100_TMEM_LOAD_16dp64b64x{}; - } - else if constexpr (bits_n % (64 * 32) == 0) { - return SM100_TMEM_LOAD_16dp64b32x{}; - } - else if constexpr (bits_n % (64 * 16) == 0) { - return SM100_TMEM_LOAD_16dp64b16x{}; - } - else if constexpr (bits_n % (64 * 8) == 0) { - return SM100_TMEM_LOAD_16dp64b8x{}; - } - else if constexpr (bits_n % (64 * 4) == 0) { - return SM100_TMEM_LOAD_16dp64b4x{}; - } - else if constexpr (bits_n % (64 * 2) == 0) { - return SM100_TMEM_LOAD_16dp64b2x{}; - } - else if constexpr (bits_n % (64 * 1) == 0) { - return SM100_TMEM_LOAD_16dp64b1x{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (64 * 128) == 0) { - return SM100_TMEM_LOAD_16dp64b128x_16b{}; - } - 
else if constexpr (bits_n % (64 * 64) == 0) { - return SM100_TMEM_LOAD_16dp64b64x_16b{}; - } - else if constexpr (bits_n % (64 * 32) == 0) { - return SM100_TMEM_LOAD_16dp64b32x_16b{}; - } - else if constexpr (bits_n % (64 * 16) == 0) { - return SM100_TMEM_LOAD_16dp64b16x_16b{}; - } - else if constexpr (bits_n % (64 * 8) == 0) { - return SM100_TMEM_LOAD_16dp64b8x_16b{}; - } - else if constexpr (bits_n % (64 * 4) == 0) { - return SM100_TMEM_LOAD_16dp64b4x_16b{}; - } - else if constexpr (bits_n % (64 * 2) == 0) { - return SM100_TMEM_LOAD_16dp64b2x_16b{}; - } - else if constexpr (bits_n % (64 * 1) == 0) { - return SM100_TMEM_LOAD_16dp64b1x_16b{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (64 * 128) == 0) { - return SM100_TMEM_LOAD_16dp32b128x{}; - } - else if constexpr (bits_n % (64 * 64) == 0) { - return SM100_TMEM_LOAD_16dp32b64x{}; - } - else if constexpr (bits_n % (64 * 32) == 0) { - return SM100_TMEM_LOAD_16dp32b32x{}; - } - else if constexpr (bits_n % (64 * 16) == 0) { - return SM100_TMEM_LOAD_16dp32b16x{}; - } - else if constexpr (bits_n % (64 * 8) == 0) { - return SM100_TMEM_LOAD_16dp32b8x{}; - } - else if constexpr (bits_n % (64 * 4) == 0) { - return SM100_TMEM_LOAD_16dp32b4x{}; - } - else if constexpr (bits_n % (64 * 2) == 0) { - return SM100_TMEM_LOAD_16dp32b2x{}; - } - else if constexpr (bits_n % (64 * 1) == 0) { - return SM100_TMEM_LOAD_16dp32b1x{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (64 * 128) == 0) { - return SM100_TMEM_LOAD_16dp32b128x_16b{}; - } - else if constexpr (bits_n % (64 * 64) == 0) { - return SM100_TMEM_LOAD_16dp32b64x_16b{}; - } - else if constexpr (bits_n % (64 * 32) == 0) { - return SM100_TMEM_LOAD_16dp32b32x_16b{}; - } - else if constexpr (bits_n % (64 * 16) == 0) { - return SM100_TMEM_LOAD_16dp32b16x_16b{}; - } - else if constexpr (bits_n % (64 * 8) == 0) { - return SM100_TMEM_LOAD_16dp32b8x_16b{}; - } - else if constexpr (bits_n % (64 * 4) == 0) { - return SM100_TMEM_LOAD_16dp32b4x_16b{}; - } - else if constexpr (bits_n % (64 * 2) == 0) { - return SM100_TMEM_LOAD_16dp32b2x_16b{}; - } - else if constexpr (bits_n % (64 * 1) == 0) { - return SM100_TMEM_LOAD_16dp32b1x_16b{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (32 * 128) == 0) { - return SM100_TMEM_LOAD_32dp32b128x{}; - } - else if constexpr (bits_n % (32 * 64) == 0) { - return SM100_TMEM_LOAD_32dp32b64x{}; - } - else if constexpr (bits_n % (32 * 32) == 0) { - return SM100_TMEM_LOAD_32dp32b32x{}; - } - else if constexpr (bits_n % (32 * 16) == 0) { - return SM100_TMEM_LOAD_32dp32b16x{}; - } - else if constexpr (bits_n % (32 * 8) == 0) { - return SM100_TMEM_LOAD_32dp32b8x{}; - } - else if constexpr (bits_n % (32 * 4) == 0) { - return SM100_TMEM_LOAD_32dp32b4x{}; - } - else if constexpr (bits_n % (32 * 2) == 0) { - return SM100_TMEM_LOAD_32dp32b2x{}; - } - else if constexpr (bits_n % (32 * 1) == 0) { - return SM100_TMEM_LOAD_32dp32b1x{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (32 * 128) == 0) { - return SM100_TMEM_LOAD_32dp32b128x_16b{}; - } - else if constexpr (bits_n % (32 * 64) == 0) { - return SM100_TMEM_LOAD_32dp32b64x_16b{}; - } - else if constexpr (bits_n % (32 * 32) == 0) { - return SM100_TMEM_LOAD_32dp32b32x_16b{}; - } - else if constexpr (bits_n % (32 * 16) == 0) { - return SM100_TMEM_LOAD_32dp32b16x_16b{}; - } - else if constexpr (bits_n % (32 * 8) == 0) { - return SM100_TMEM_LOAD_32dp32b8x_16b{}; - } - else if constexpr (bits_n % (32 * 4) == 0) { - return 
SM100_TMEM_LOAD_32dp32b4x_16b{}; - } - else if constexpr (bits_n % (32 * 2) == 0) { - return SM100_TMEM_LOAD_32dp32b2x_16b{}; - } - else if constexpr (bits_n % (32 * 1) == 0) { - return SM100_TMEM_LOAD_32dp32b1x_16b{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (256 * 32) == 0) { - return SM100_TMEM_STORE_16dp256b32x{}; - } - else if constexpr (bits_n % (256 * 16) == 0) { - return SM100_TMEM_STORE_16dp256b16x{}; - } - else if constexpr (bits_n % (256 * 8) == 0) { - return SM100_TMEM_STORE_16dp256b8x{}; - } - else if constexpr (bits_n % (256 * 4) == 0) { - return SM100_TMEM_STORE_16dp256b4x{}; - } - else if constexpr (bits_n % (256 * 2) == 0) { - return SM100_TMEM_STORE_16dp256b2x{}; - } - else if constexpr (bits_n % (256 * 1) == 0) { - return SM100_TMEM_STORE_16dp256b1x{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (256 * 32) == 0) { - return SM100_TMEM_STORE_16dp256b32x_16b{}; - } - else if constexpr (bits_n % (256 * 16) == 0) { - return SM100_TMEM_STORE_16dp256b16x_16b{}; - } - else if constexpr (bits_n % (256 * 8) == 0) { - return SM100_TMEM_STORE_16dp256b8x_16b{}; - } - else if constexpr (bits_n % (256 * 4) == 0) { - return SM100_TMEM_STORE_16dp256b4x_16b{}; - } - else if constexpr (bits_n % (256 * 2) == 0) { - return SM100_TMEM_STORE_16dp256b2x_16b{}; - } - else if constexpr (bits_n % (256 * 1) == 0) { - return SM100_TMEM_STORE_16dp256b1x_16b{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (128 * 64) == 0) { - return SM100_TMEM_STORE_16dp128b64x{}; - } - else if constexpr (bits_n % (128 * 32) == 0) { - return SM100_TMEM_STORE_16dp128b32x{}; - } - else if constexpr (bits_n % (128 * 16) == 0) { - return SM100_TMEM_STORE_16dp128b16x{}; - } - else if constexpr (bits_n % (128 * 8) == 0) { - return SM100_TMEM_STORE_16dp128b8x{}; - } - else if constexpr (bits_n % (128 * 4) == 0) { - return SM100_TMEM_STORE_16dp128b4x{}; - } - else if constexpr (bits_n % (128 * 2) == 0) { - return SM100_TMEM_STORE_16dp128b2x{}; - } - else if constexpr (bits_n % (128 * 1) == 0) { - return SM100_TMEM_STORE_16dp128b1x{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (128 * 64) == 0) { - return SM100_TMEM_STORE_16dp128b64x_16b{}; - } - else if constexpr (bits_n % (128 * 32) == 0) { - return SM100_TMEM_STORE_16dp128b32x_16b{}; - } - else if constexpr (bits_n % (128 * 16) == 0) { - return SM100_TMEM_STORE_16dp128b16x_16b{}; - } - else if constexpr (bits_n % (128 * 8) == 0) { - return SM100_TMEM_STORE_16dp128b8x_16b{}; - } - else if constexpr (bits_n % (128 * 4) == 0) { - return SM100_TMEM_STORE_16dp128b4x_16b{}; - } - else if constexpr (bits_n % (128 * 2) == 0) { - return SM100_TMEM_STORE_16dp128b2x_16b{}; - } - else if constexpr (bits_n % (128 * 1) == 0) { - return SM100_TMEM_STORE_16dp128b1x_16b{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (64 * 128) == 0) { - return SM100_TMEM_STORE_16dp64b128x{}; - } - else if constexpr (bits_n % (64 * 64) == 0) { - return SM100_TMEM_STORE_16dp64b64x{}; - } - else if constexpr (bits_n % (64 * 32) == 0) { - return SM100_TMEM_STORE_16dp64b32x{}; - } - else if constexpr (bits_n % (64 * 16) == 0) { - return SM100_TMEM_STORE_16dp64b16x{}; - } - else if constexpr (bits_n % (64 * 8) == 0) { - return SM100_TMEM_STORE_16dp64b8x{}; - } - else if constexpr (bits_n % (64 * 4) == 0) { - return SM100_TMEM_STORE_16dp64b4x{}; - } - else if constexpr (bits_n % (64 * 2) == 0) { - return SM100_TMEM_STORE_16dp64b2x{}; - } - else if constexpr (bits_n % 
(64 * 1) == 0) { - return SM100_TMEM_STORE_16dp64b1x{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (64 * 128) == 0) { - return SM100_TMEM_STORE_16dp64b128x_16b{}; - } - else if constexpr (bits_n % (64 * 64) == 0) { - return SM100_TMEM_STORE_16dp64b64x_16b{}; - } - else if constexpr (bits_n % (64 * 32) == 0) { - return SM100_TMEM_STORE_16dp64b32x_16b{}; - } - else if constexpr (bits_n % (64 * 16) == 0) { - return SM100_TMEM_STORE_16dp64b16x_16b{}; - } - else if constexpr (bits_n % (64 * 8) == 0) { - return SM100_TMEM_STORE_16dp64b8x_16b{}; - } - else if constexpr (bits_n % (64 * 4) == 0) { - return SM100_TMEM_STORE_16dp64b4x_16b{}; - } - else if constexpr (bits_n % (64 * 2) == 0) { - return SM100_TMEM_STORE_16dp64b2x_16b{}; - } - else if constexpr (bits_n % (64 * 1) == 0) { - return SM100_TMEM_STORE_16dp64b1x_16b{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (64 * 128) == 0) { - return SM100_TMEM_STORE_16dp32b128x{}; - } - else if constexpr (bits_n % (64 * 64) == 0) { - return SM100_TMEM_STORE_16dp32b64x{}; - } - else if constexpr (bits_n % (64 * 32) == 0) { - return SM100_TMEM_STORE_16dp32b32x{}; - } - else if constexpr (bits_n % (64 * 16) == 0) { - return SM100_TMEM_STORE_16dp32b16x{}; - } - else if constexpr (bits_n % (64 * 8) == 0) { - return SM100_TMEM_STORE_16dp32b8x{}; - } - else if constexpr (bits_n % (64 * 4) == 0) { - return SM100_TMEM_STORE_16dp32b4x{}; - } - else if constexpr (bits_n % (64 * 2) == 0) { - return SM100_TMEM_STORE_16dp32b2x{}; - } - else if constexpr (bits_n % (64 * 1) == 0) { - return SM100_TMEM_STORE_16dp32b1x{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (64 * 128) == 0) { - return SM100_TMEM_STORE_16dp32b128x_16b{}; - } - else if constexpr (bits_n % (64 * 64) == 0) { - return SM100_TMEM_STORE_16dp32b64x_16b{}; - } - else if constexpr (bits_n % (64 * 32) == 0) { - return SM100_TMEM_STORE_16dp32b32x_16b{}; - } - else if constexpr (bits_n % (64 * 16) == 0) { - return SM100_TMEM_STORE_16dp32b16x_16b{}; - } - else if constexpr (bits_n % (64 * 8) == 0) { - return SM100_TMEM_STORE_16dp32b8x_16b{}; - } - else if constexpr (bits_n % (64 * 4) == 0) { - return SM100_TMEM_STORE_16dp32b4x_16b{}; - } - else if constexpr (bits_n % (64 * 2) == 0) { - return SM100_TMEM_STORE_16dp32b2x_16b{}; - } - else if constexpr (bits_n % (64 * 1) == 0) { - return SM100_TMEM_STORE_16dp32b1x_16b{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (32 * 128) == 0) { - return SM100_TMEM_STORE_32dp32b128x{}; - } - else if constexpr (bits_n % (32 * 64) == 0) { - return SM100_TMEM_STORE_32dp32b64x{}; - } - else if constexpr (bits_n % (32 * 32) == 0) { - return SM100_TMEM_STORE_32dp32b32x{}; - } - else if constexpr (bits_n % (32 * 16) == 0) { - return SM100_TMEM_STORE_32dp32b16x{}; - } - else if constexpr (bits_n % (32 * 8) == 0) { - return SM100_TMEM_STORE_32dp32b8x{}; - } - else if constexpr (bits_n % (32 * 4) == 0) { - return SM100_TMEM_STORE_32dp32b4x{}; - } - else if constexpr (bits_n % (32 * 2) == 0) { - return SM100_TMEM_STORE_32dp32b2x{}; - } - else if constexpr (bits_n % (32 * 1) == 0) { - return SM100_TMEM_STORE_32dp32b1x{}; - } - } - else if constexpr (cute::is_same_v) { - if constexpr (bits_n % (32 * 128) == 0) { - return SM100_TMEM_STORE_32dp32b128x_16b{}; - } - else if constexpr (bits_n % (32 * 64) == 0) { - return SM100_TMEM_STORE_32dp32b64x_16b{}; - } - else if constexpr (bits_n % (32 * 32) == 0) { - return SM100_TMEM_STORE_32dp32b32x_16b{}; - } - else if constexpr 
(bits_n % (32 * 16) == 0) { - return SM100_TMEM_STORE_32dp32b16x_16b{}; - } - else if constexpr (bits_n % (32 * 8) == 0) { - return SM100_TMEM_STORE_32dp32b8x_16b{}; - } - else if constexpr (bits_n % (32 * 4) == 0) { - return SM100_TMEM_STORE_32dp32b4x_16b{}; - } - else if constexpr (bits_n % (32 * 2) == 0) { - return SM100_TMEM_STORE_32dp32b2x_16b{}; - } - else if constexpr (bits_n % (32 * 1) == 0) { - return SM100_TMEM_STORE_32dp32b1x_16b{}; - } - } - else { - static_assert(dependent_false, "Must pass 1x tmem copy operator"); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Select TMEM store corresponding to the provided TMEM load -template -CUTE_HOST_DEVICE constexpr auto -tmem_load_to_store(CopyOp) { - if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b1x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b1x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b2x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b2x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b4x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b4x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b8x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b8x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b16x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b16x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b32x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp256b32x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b1x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b1x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b2x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b2x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b4x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b4x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b8x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b8x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b16x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b16x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b32x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b32x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b64x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp128b64x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b1x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b1x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b2x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b2x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b4x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b4x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b8x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b8x_16b{}; - } - else if constexpr (is_same_v) { - return 
SM100_TMEM_STORE_16dp64b16x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b16x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b32x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b32x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b64x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b64x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b128x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp64b128x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b1x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b1x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b2x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b2x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b4x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b4x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b8x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b8x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b16x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b16x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b32x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b32x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b64x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b64x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b128x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_16dp32b128x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b1x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b1x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b2x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b2x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b4x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b4x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b8x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b8x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b16x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b16x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b32x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b32x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b64x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b64x_16b{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b128x{}; - } - else if constexpr (is_same_v) { - return SM100_TMEM_STORE_32dp32b128x_16b{}; - } - else { - static_assert(dependent_false, "No TMEM_STORE matching for provided TMEM_LOAD"); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace TMEM - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
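A worked instance of the selection rule in the op_repeater / tmem_load_to_store helpers deleted above (this patch removes them from this header; where they move is not visible in these hunks). The arithmetic is taken directly from the removed code:

// For a 16dp256b TMEM_LOAD covering 2048 bits of the N-mode, the widest repeat
// whose footprint divides 2048 is the 8x variant:
static_assert(2048 % (256 * 16) != 0, "16x footprint (4096b) does not divide 2048b");
static_assert(2048 % (256 *  8) == 0, "8x footprint (2048b) divides 2048b");
// so op_repeater<SM100_TMEM_LOAD_16dp256b1x, 2048>() returns SM100_TMEM_LOAD_16dp256b8x{},
// and tmem_load_to_store(SM100_TMEM_LOAD_16dp256b8x{}) pairs it with SM100_TMEM_STORE_16dp256b8x{}.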
-//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// TMEM_LOAD Copy Traits -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - // Logical thread id to thread idx (warp) - using ThrID = Layout<_32>; - // Logical bit id to bit idx (address) - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_64, _2>>, - Stride,Stride< _1,_2048>>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2>>, - Stride,Stride< _1,_2048>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _2>>, - Stride,Stride< _1,_4096,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b2x_16b; + +template <> +struct Copy_Traits { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1242,9 +504,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b4x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1258,9 +521,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b4x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1274,9 +538,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b8x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1290,9 +555,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b8x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1306,9 +572,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b16x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1322,9 +589,10 @@ struct Copy_Traits 
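A note on the restructuring visible in these hunks: copy_unpack now lives in the namespace of each CopyOp family (SM100::TMEM::LOAD / ::STORE) instead of being inherited from the TMEM_LOAD_Unpack / TMEM_STORE_Unpack base classes, so the generic copy path can presumably still reach it through argument-dependent lookup on the CopyOp template argument. A self-contained toy (names are placeholders, not CUTLASS APIs) showing that lookup mechanism:

#include <cstdio>

namespace lib {
  template <class Op> struct Traits {};          // stand-in for cute::Copy_Traits<Op>
}
namespace lib::ops {
  struct LoadOp {};                              // stand-in for an SM100::TMEM::LOAD op
  template <class Op>
  void copy_unpack(lib::Traits<Op> const&) {     // found by ADL via the Op template argument
    std::puts("ops::copy_unpack selected");
  }
}

int main() {
  lib::Traits<lib::ops::LoadOp> t;
  copy_unpack(t);                                // unqualified call; ADL finds lib::ops::copy_unpack
}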
//////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b16x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1338,9 +606,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b32x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1354,9 +623,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b32x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1370,9 +640,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b1x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1386,9 +657,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b1x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1402,9 +674,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b2x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1418,9 +691,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b2x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1434,9 +708,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b4x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1450,9 +725,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b4x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1466,9 +742,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b8x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1482,9 +759,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b8x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1498,9 +776,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b16x; + template <> struct 
Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1514,9 +793,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b16x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1530,9 +810,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b32x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1546,9 +827,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b32x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1562,9 +844,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b64x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1578,9 +861,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b64x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1594,9 +878,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b1x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1610,9 +895,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b1x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1626,9 +912,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b2x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1642,9 +929,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b2x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1658,9 +946,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b4x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1674,9 +963,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b4x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1690,9 +980,10 @@ struct Copy_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b8x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1706,9 +997,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b8x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1722,9 +1014,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b16x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1738,9 +1031,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b16x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1754,9 +1048,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b32x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1770,9 +1065,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b32x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1786,9 +1082,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b64x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1802,9 +1099,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b64x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1818,9 +1116,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b128x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1834,9 +1133,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b128x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1850,9 +1150,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b1x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1866,9 +1167,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b1x_16b; + template <> struct 
Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1882,9 +1184,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b2x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1898,9 +1201,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b2x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1914,9 +1218,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b4x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1930,9 +1235,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b4x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1946,9 +1252,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b8x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1962,9 +1269,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b8x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -1978,9 +1286,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b16x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -1994,9 +1303,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b16x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -2010,9 +1320,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b32x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2026,9 +1337,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b32x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -2042,9 +1354,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b64x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2058,9 +1371,10 @@ struct Copy_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b64x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -2074,9 +1388,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b128x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2090,9 +1405,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b128x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _16>, @@ -2106,9 +1422,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b1x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2122,9 +1439,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b1x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _32>, @@ -2138,9 +1456,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b2x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2154,9 +1473,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b2x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _32>, @@ -2170,9 +1490,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b4x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2186,9 +1507,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b4x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _32>, @@ -2202,9 +1524,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b8x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2218,9 +1541,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b8x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _32>, @@ -2234,9 +1558,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b16x; + template <> struct 
Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2250,9 +1575,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b16x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _32>, @@ -2266,9 +1592,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b32x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2281,9 +1608,11 @@ struct Copy_Traits }; //////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b32x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _32>, @@ -2297,9 +1626,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2313,9 +1643,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _32>, @@ -2329,9 +1660,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b128x; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, @@ -2344,9 +1676,11 @@ struct Copy_Traits }; //////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b128x_16b; + template <> struct Copy_Traits - : TMEM_LOAD_Unpack { using ThrID = Layout<_32>; using ValID = Layout, _32>, @@ -2368,9 +1702,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b1x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2381,9 +1716,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b1x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2394,9 +1730,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b2x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2407,9 +1744,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b2x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = 
typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2420,9 +1758,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b4x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2433,9 +1772,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b4x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2446,9 +1786,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b8x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2459,9 +1800,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b8x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2472,9 +1814,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b16x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2485,9 +1828,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b16x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2498,9 +1842,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b32x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2511,9 +1856,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b32x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2524,9 +1870,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b1x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2537,9 +1884,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b1x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2550,9 +1898,10 @@ struct 
Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b2x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2563,9 +1912,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b2x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2576,9 +1926,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b4x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2589,9 +1940,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b4x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2602,9 +1954,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b8x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2615,9 +1968,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b8x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2628,9 +1982,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b16x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2641,9 +1996,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b16x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2654,9 +2010,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b32x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2667,9 +2024,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b32x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2680,9 +2038,10 @@ struct Copy_Traits 
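The TMEM_STORE specializations in these hunks show only their ThrID/ValID aliases (the template arguments were lost in this view). As a hedged reconstruction of the visible pattern, a store trait appears to reuse its matching load trait's maps with the source/destination roles exchanged, roughly:

// Hedged sketch; the SrcLayout/DstLayout/RefLayout members below are an assumption
// inferred from the LOAD traits above, not text shown in these hunks.
template <>
struct Copy_Traits<SM100_TMEM_STORE_16dp256b1x>
{
  using LoadTraits = Copy_Traits<SM100_TMEM_LOAD_16dp256b1x>;
  using ThrID      = typename LoadTraits::ThrID;       // these two aliases are visible in the hunk
  using ValID      = typename LoadTraits::ValID;
  using SrcLayout  = typename LoadTraits::DstLayout;   // assumed: roles swapped relative to the load
  using DstLayout  = typename LoadTraits::SrcLayout;
  using RefLayout  = SrcLayout;                        // assumed
};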
//////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b64x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2693,9 +2052,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b64x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2706,9 +2066,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b1x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2719,9 +2080,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b1x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2732,9 +2094,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b2x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2745,9 +2108,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b2x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2758,9 +2122,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b4x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2771,9 +2136,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b4x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2784,9 +2150,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b8x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2797,9 +2164,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b8x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2810,9 +2178,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using 
SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b16x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2823,9 +2192,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b16x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2836,9 +2206,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b32x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2849,9 +2220,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b32x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2862,9 +2234,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b64x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2875,9 +2248,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b64x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2888,9 +2262,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b128x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2901,9 +2276,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b128x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2914,9 +2290,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b1x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2927,9 +2304,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b1x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2940,9 +2318,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b2x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using 
ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2953,9 +2332,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b2x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2966,9 +2346,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b4x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2979,9 +2360,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b4x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -2992,9 +2374,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b8x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3005,9 +2388,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b8x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3018,9 +2402,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b16x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3031,9 +2416,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b16x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3044,9 +2430,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b32x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3057,9 +2444,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b32x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3070,9 +2458,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b64x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3083,9 +2472,10 @@ struct 
Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b64x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3096,9 +2486,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b128x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3109,9 +2500,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b128x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3122,9 +2514,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b1x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3135,9 +2528,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b1x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3148,9 +2542,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b2x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3161,9 +2556,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b2x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3174,9 +2570,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b4x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3187,9 +2584,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b4x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3200,9 +2598,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b8x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3213,9 +2612,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using 
SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b8x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3226,9 +2626,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b16x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3239,9 +2640,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b16x_16b; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3252,9 +2654,10 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b32x; + template <> struct Copy_Traits - : TMEM_STORE_Unpack { using ThrID = typename Copy_Traits::ThrID; using ValID = typename Copy_Traits::ValID; @@ -3265,76 +2668,841 @@ struct Copy_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + 
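Illustrative note, not part of the patch: every TMEM_STORE specialization in this hunk follows one pattern, which the flattened rendering obscures because the template arguments of Copy_Traits were lost. Each store trait reuses the ThrID and ValID of the matching TMEM_LOAD trait and swaps that trait's source and destination layouts, since a register-to-tmem store is the inverse direction of the corresponding tmem-to-register load. A minimal sketch of the pattern follows, with the op name and the LOAD-side namespace written out as assumptions rather than quoted from the patch:

    // Illustrative sketch only; mirrors the specializations above.
    template <>
    struct Copy_Traits<SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b32x>
    {
      // Reuse the traits of the paired LOAD op (assumed to live in SM100::TMEM::LOAD).
      using Load      = Copy_Traits<SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b32x>;
      using ThrID     = typename Load::ThrID;
      using ValID     = typename Load::ValID;
      using SrcLayout = typename Load::DstLayout;  // store source: register fragment
      using DstLayout = typename Load::SrcLayout;  // store destination: tmem
      using RefLayout = typename Load::RefLayout;
    };

In device code these atoms are typically wrapped into a TiledCopy by the tmem copy helpers defined elsewhere in this header and driven with cute::copy; those entry points are outside this hunk.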
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace TMEM { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Given a 1x tmem copy op, returns the widest repeated variant that divides the specified bits in the N-mode +template +CUTE_HOST_DEVICE constexpr +auto +op_repeater() +{ + if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_LOAD_16dp256b32x{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_LOAD_16dp256b16x{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_LOAD_16dp256b8x{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_LOAD_16dp256b4x{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_LOAD_16dp256b2x{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_LOAD_16dp256b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_LOAD_16dp256b32x_16b{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_LOAD_16dp256b16x_16b{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_LOAD_16dp256b8x_16b{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_LOAD_16dp256b4x_16b{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_LOAD_16dp256b2x_16b{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_LOAD_16dp256b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_LOAD_16dp128b64x{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_LOAD_16dp128b32x{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_LOAD_16dp128b16x{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_LOAD_16dp128b8x{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_LOAD_16dp128b4x{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_LOAD_16dp128b2x{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_LOAD_16dp128b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_LOAD_16dp128b64x_16b{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_LOAD_16dp128b32x_16b{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_LOAD_16dp128b16x_16b{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_LOAD_16dp128b8x_16b{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_LOAD_16dp128b4x_16b{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_LOAD_16dp128b2x_16b{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_LOAD_16dp128b1x_16b{}; + } + } + 
else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp64b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp64b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp64b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp64b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp64b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp64b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp64b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp64b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp64b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp64b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp64b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp64b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp64b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp64b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp64b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp64b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp32b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp32b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp32b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp32b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp32b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp32b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp32b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp32b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp32b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp32b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp32b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp32b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp32b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp32b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_LOAD_32dp32b128x{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_LOAD_32dp32b64x{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_LOAD_32dp32b32x{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_LOAD_32dp32b16x{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { 
+ return SM100_TMEM_LOAD_32dp32b8x{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_LOAD_32dp32b4x{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_LOAD_32dp32b2x{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_LOAD_32dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_LOAD_32dp32b128x_16b{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_LOAD_32dp32b64x_16b{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_LOAD_32dp32b32x_16b{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_LOAD_32dp32b16x_16b{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_LOAD_32dp32b8x_16b{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_LOAD_32dp32b4x_16b{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_LOAD_32dp32b2x_16b{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_LOAD_32dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_STORE_16dp256b32x{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_STORE_16dp256b16x{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_STORE_16dp256b8x{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_STORE_16dp256b4x{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_STORE_16dp256b2x{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_STORE_16dp256b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_STORE_16dp256b32x_16b{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_STORE_16dp256b16x_16b{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_STORE_16dp256b8x_16b{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_STORE_16dp256b4x_16b{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_STORE_16dp256b2x_16b{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_STORE_16dp256b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_STORE_16dp128b64x{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_STORE_16dp128b32x{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_STORE_16dp128b16x{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_STORE_16dp128b8x{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_STORE_16dp128b4x{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_STORE_16dp128b2x{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_STORE_16dp128b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_STORE_16dp128b64x_16b{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_STORE_16dp128b32x_16b{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_STORE_16dp128b16x_16b{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_STORE_16dp128b8x_16b{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_STORE_16dp128b4x_16b{}; + } + else if constexpr 
(bits_n % (128 * 2) == 0) { + return SM100_TMEM_STORE_16dp128b2x_16b{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_STORE_16dp128b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp64b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp64b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp64b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp64b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp64b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp64b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp64b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp64b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp64b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp64b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp64b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp64b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp64b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp64b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp64b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp64b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp32b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp32b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp32b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp32b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp32b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp32b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp32b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp32b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp32b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp32b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp32b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp32b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp32b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp32b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_STORE_32dp32b128x{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_STORE_32dp32b64x{}; + } + else if 
constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_STORE_32dp32b32x{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_STORE_32dp32b16x{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_STORE_32dp32b8x{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_STORE_32dp32b2x{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_STORE_32dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_STORE_32dp32b128x_16b{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_STORE_32dp32b64x_16b{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_STORE_32dp32b32x_16b{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_STORE_32dp32b16x_16b{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_STORE_32dp32b8x_16b{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_STORE_32dp32b4x_16b{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_STORE_32dp32b2x_16b{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_STORE_32dp32b1x_16b{}; + } + } + else { + static_assert(dependent_false, "Must pass 1x tmem copy operator"); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Select TMEM store corresponding to the provided TMEM load +template +CUTE_HOST_DEVICE constexpr auto +tmem_load_to_store(CopyOp) { + if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b32x{}; + } + else if constexpr (is_same_v) { + return 
SM100_TMEM_STORE_16dp128b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b128x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b128x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b16x_16b{}; + } + else if 
constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b128x_16b{}; + } + else { + static_assert(dependent_false, "No TMEM_STORE matching for provided TMEM_LOAD"); + } +} -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; +} // namespace TMEM //////////////////////////////////////////////////////////////////////////////////////////////////// -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - //////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - +// +// UTCCP Copy Traits +// //////////////////////////////////////////////////////////////////////////////////////////////////// -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// +namespace SM100::TMEM::UTCCP { -//////////////////////////////////////////////////////////////////////////////////////////////////// // -// UTCCP Copy Traits +// Specialized copy_unpack implementation for SM100::TMEM::UTCCP instructions // -//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const&, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + CopyOp::copy(src[0], raw_pointer_cast(dst.data())); +} + +} // end namespace SM100::TMEM::UTCCP // In the following UTCCP traits, the ValID is representing: // logical_bit_idx -> tmem_addr_offset. @@ -3344,131 +3512,76 @@ struct Copy_Traits // The last two modes provide boradcast transformation for 4x32DP and 2x64DP. // With above, the strides of first two modes are neccessary to be TMEM::DP_b and 1. // And the stride of the third mode in the SrcLayout must be zero. 
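Illustrative note, not part of the patch: the two helpers added above are compile-time selectors. op_repeater widens a 1x TMEM copy op to the largest repeat count whose footprint divides the accumulator's N-extent in bits (for example, a 2048-bit N extent selects the 8x variant of a 16dp256b op, since 2048 = 256 * 8), and tmem_load_to_store maps a TMEM_LOAD op to its register-to-tmem counterpart. A short sketch of the pairing, assuming code placed inside namespace cute with this header's using-declarations for the LOAD ops in scope:

    // Given the LOAD op used to read accumulators out of tmem, derive the matching STORE op.
    using T2R = SM100_TMEM_LOAD_16dp256b8x;                   // tmem -> registers
    using R2T = decltype(TMEM::tmem_load_to_store(T2R{}));    // registers -> tmem
    static_assert(cute::is_same_v<R2T, SM100_TMEM_STORE_16dp256b8x>,
                  "tmem_load_to_store pairs each LOAD with the STORE of the same shape");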
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp256bit_1cta; + template <> struct Copy_Traits { using ThrID = Layout<_1>; - // logical bit_idx -> tmem_addr using ValID = Layout, Stride>; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0,_1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_128dp256bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp256bit_2cta; template <> struct Copy_Traits { using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_128dp256bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp128bit_1cta; template <> struct Copy_Traits { using ThrID = Layout<_1>; - // logical bit_idx -> tmem_addr using ValID = Layout, Stride>; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0,_1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_128dp128bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp128bit_2cta; template <> struct Copy_Traits { using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - 
static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_128dp128bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_4dp256bit_1cta; template <> struct Copy_Traits @@ -3485,65 +3598,34 @@ struct Copy_Traits */ using ThrID = Layout<_1>; - // logical bit_idx -> tmem_addr using ValID = Layout, Stride>; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_32,_128>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_0,Stride<_32,_128>>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_4dp256bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_4dp256bit_2cta; template <> struct Copy_Traits { - using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_32,_128>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_0,Stride<_32,_128>>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_4dp256bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_4x32dp128bit_1cta; template <> struct Copy_Traits @@ -3556,63 +3638,32 @@ struct Copy_Traits // [core_matrix_strided, core_matrix_leading, broadcast] using ValID = Layout, Stride<_DP,_1, _DPx32>>; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _32, _0>>>; - - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0,_1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_4x32dp128bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_4x32dp128bit_2cta; template <> struct Copy_Traits { - using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _32, _0>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0,_1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - 
template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_4x32dp128bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0213_1cta; template <> struct Copy_Traits @@ -3625,62 +3676,33 @@ struct Copy_Traits // [core_matrix_strided, core_matrix_leading, broadcast] using ValID = Layout, Stride<_DP,_1, _DPx64>>; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _64, _0>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_2x64dp128bitlw0213_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0213_2cta; template <> struct Copy_Traits { - using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _64, _0>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_2x64dp128bitlw0213_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0123_1cta; template <> struct Copy_Traits @@ -3695,62 +3717,31 @@ struct Copy_Traits using ValID = Layout, Stride<_DP,_1 ,_DPx64,_DPx32>>; - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _32,_4096,_0>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_2x64dp128bitlw0123_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } -}; +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0123_2cta; template <> struct Copy_Traits { - using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _32, _4096,_0>>>; - // Map from (dst-thr,dst-val) to bit using 
DstLayout = Layout, Stride<_0,_1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_2x64dp128bitlw0123_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + template CUTE_HOST_DEVICE constexpr @@ -3775,4 +3766,3 @@ make_utccp_copy(CopyOp const&, } // namespace cute -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp b/include/cute/atom/copy_traits_sm90_im2col.hpp index beefa63f6c..e4d1e3ffff 100644 --- a/include/cute/atom/copy_traits_sm90_im2col.hpp +++ b/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -647,7 +647,7 @@ make_tma_atom_im2col(CopyOp, gtensor_cwhdn, range_c, range_whdn, - detail::get_swizzle_portion(slayout), + get_swizzle_portion(slayout), tma_layout_vt, lower_corner_whd, upper_corner_whd, diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index a96291e138..08141a0920 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -454,8 +454,6 @@ struct TiledMMA : MMA_Atom { // (M,K) -> (M,K) auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); - // (athrid,val) -> (M,K) - auto layoutA_TV = thrfrg_A(ref_A); // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) auto atile = make_tile(_, @@ -493,8 +491,6 @@ struct TiledMMA : MMA_Atom { // (N,K) -> (N,K) auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); - // (bthrid,val) -> (N,K) - auto layoutB_TV = thrfrg_B(ref_B); // (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) auto btile = make_tile(_, @@ -1192,6 +1188,7 @@ print_svg(TiledMMA const &mma) { #include #include #include +#include #include #include #include diff --git a/include/cute/atom/mma_traits_sm100.hpp b/include/cute/atom/mma_traits_sm100.hpp index f336eff215..820dc103e1 100644 --- a/include/cute/atom/mma_traits_sm100.hpp +++ b/include/cute/atom/mma_traits_sm100.hpp @@ -37,10 +37,13 @@ #include #include #include -#include // cute::TMEM:: +#include // cute::TMEM:: + #include #include // cute::GMMA:: #include // cute::GMMA:: +#include // UTCCP smem desc + #include // Check that aggregate initialization in .with() initializes all fields @@ -417,6 +420,9 @@ constexpr auto get_utccp_smem_desc_tensor(Tensor const& smem_u namespace UMMA { +// Import TMEM constants +namespace TMEM = cute::TMEM; + enum class TmemAllocMode { // Default allocation mode. 
// If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms can be @@ -3053,7 +3059,7 @@ struct MMA_Traits <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types"); static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256."); - + using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_2sm; diff --git a/include/cute/atom/mma_traits_sm89.hpp b/include/cute/atom/mma_traits_sm89.hpp new file mode 100644 index 0000000000..35ad436e22 --- /dev/null +++ b/include/cute/atom/mma_traits_sm89.hpp @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +// + +// +#pragma once + +#include +#include +#include +#include + +namespace cute +{ + +namespace { + +// (T32,V4) -> (M16,N8) +using SM80_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +template <> +struct MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_32>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_16,_8,_256>>>; + using BLayout = Layout,Shape <_4, _2>>, + Stride,Stride<_8,_128>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits : +MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits : +MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits : +MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; +}; + +} // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index e688a7e6a8..e1c3bb4034 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -322,7 +322,11 @@ struct DescriptorIterator CUTE_HOST_DEVICE constexpr DescriptorIterator operator+(Index const& offset) const { - return { GmmaDescriptor{desc_ + uint64_t(offset)} }; + // Use 32bit calculation rather than 64 bit calculation as we only update the part of desc + GmmaDescriptor ret; + ret.reg32_[0] = desc_.reg32_[0] + uint32_t(offset); + ret.reg32_[1] = desc_.reg32_[1]; + return { ret }; } }; diff --git a/include/cute/config.hpp b/include/cute/config.hpp index 3ac4c1024f..f3d72f257f 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -151,6 +151,16 @@ # include #endif +// +// Type +// + +#if defined(__CUDACC_RTC__) +# include +#else +# include +#endif + // // Debugging utilities // diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp index ed4f8c8c23..9a13e951be 100644 --- a/include/cute/container/tuple.hpp +++ b/include/cute/container/tuple.hpp @@ -53,8 +53,8 @@ // but do _not_ include references like int& or float&. // (See std::tie for an example of a tuple of references.) // -// Standard-layout types preserve ABI across host-device boundaries. -// They are safe to use as device kernel parameters. +// Standard-layout types preserve ABI across host-device boundaries. They are safe to use as device kernel parameters. +// The standard-layout requirement prevents a more common EBO-based implemented of cute::tuple. // // The cute::tuple is also simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of // the conversion SFINAE, special overloading, and avoiding cvref template types. @@ -64,12 +64,15 @@ namespace cute { -namespace detail +template +struct tuple; + +namespace eso { // ESO stands for "empty structure optimization." -// We use this technique to ensure that cute::tuple -// doesn't waste space storing template arguments that have no data (like integral_constant). +// We use this technique to ensure that cute::tuple doesn't waste space +// storing template arguments that have no data (like integral_constant). 
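The surrounding comments describe cute::tuple's "empty structure optimization" (ESO): elements of empty type are never stored, and `get` simply materializes a fresh instance when asked. A minimal standalone sketch of that idea, assuming nothing about the real cute:: internals (the `slot` name is purely illustrative):

```cpp
#include <type_traits>

// One element "slot": stores the value only when T actually carries data.
template <class T, bool IsEmpty = std::is_empty<T>::value>
struct slot {                      // non-empty case: real storage
  T value_;
  T const& get() const { return value_; }
};

template <class T>
struct slot<T, true> {             // empty case: no member, build on demand
  T get() const { return T{}; }
};

struct Empty {};                   // stand-in for integral_constant-like types

int main() {
  slot<int>   a{42};
  slot<Empty> b;
  static_assert(sizeof(slot<Empty>) == 1, "no data members stored");
  (void)b.get();                   // returns a freshly constructed Empty{}
  return a.get() == 42 ? 0 : 1;
}
```

Because the empty case holds no members at all, the aggregate stays standard-layout without leaning on empty-base optimization, which is the constraint the surrounding comments call out.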
// Empty types in the template argument list are not even constructed, // and do not have unique element addresses. Calling `get` // constructs and returns an instance of an empty type on demand. @@ -133,88 +136,92 @@ struct ESO { }; // Get Nth value from ESO -template -CUTE_HOST_DEVICE constexpr -cute::enable_if_t>>::value, - cute::tuple_element_t>> -getv(ESO const&) -{ - return {}; -} - -template +template CUTE_HOST_DEVICE constexpr -cute::enable_if_t>>::value, - cute::tuple_element_t> const&> -getv(ESO const& s) +R +getr(S&& s) noexcept { if constexpr (N == 0) { - return static_cast(s.first_); + return static_cast(s).first_; } else { - return getv(s.rest_); + return getr(static_cast(s).rest_); } + CUTE_GCC_UNREACHABLE; } -template +// Compilers disagree on decltype(auto), so these implementations avoid it at cost +template CUTE_HOST_DEVICE constexpr -cute::enable_if_t>>::value, - cute::tuple_element_t> &> -getv(ESO& s) +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> const&> +getv_cr(ESO const& s) noexcept { - if constexpr (N == 0) { - return static_cast(s.first_); + if constexpr (cute::is_empty>>::value) { + return {}; } else { - return getv(s.rest_); + return getr> const&, N>(s); } + CUTE_GCC_UNREACHABLE; } -template +template CUTE_HOST_DEVICE constexpr -cute::enable_if_t>>::value, - cute::tuple_element_t> &&> -getv(ESO&& s) +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> &> +getv_r(ESO& s) noexcept { - if constexpr (N == 0) { - return static_cast(s.first_); + if constexpr (cute::is_empty>>::value) { + return {}; } else { - return getv(static_cast&&>(s.rest_)); + return getr> &, N>(s); } + CUTE_GCC_UNREACHABLE; } -template +template CUTE_HOST_DEVICE constexpr -auto -findt(ESO const& t) noexcept -{ - if constexpr (cute::is_same_v) { - return C{}; - } else - if constexpr (sizeof...(Rest) == 0) { - return C{}; - } else - if constexpr (IsRestEmpty) { - return cute::detail::findt(ESO_t{}); +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> &&> +getv_rr(ESO&& s) noexcept +{ + if constexpr (cute::is_empty>>::value) { + return {}; } else { - return cute::detail::findt(t.rest_); + return getr> &&, N>(static_cast&&>(s)); } + CUTE_GCC_UNREACHABLE; } -} // end namespace detail +} // end namespace eso template -struct tuple : detail::ESO_t +struct tuple : eso::ESO_t { CUTE_HOST_DEVICE constexpr tuple() {} CUTE_HOST_DEVICE constexpr - tuple(T const&... t) : detail::ESO_t(t...) {} + tuple(T const&... t) : eso::ESO_t(t...) {} }; template <> struct tuple<> {}; +// +// make_tuple (value-based implementation) +// + +template +CUTE_HOST_DEVICE constexpr +tuple +make_tuple(T const&... t) +{ + return {t...}; +} + // Returns the element in the ith position of the tuple template CUTE_HOST_DEVICE constexpr @@ -222,7 +229,7 @@ decltype(auto) get(tuple const& t) noexcept { static_assert(I < sizeof...(T), "Index out of range"); - return detail::getv(t); + return eso::getv_cr(t); } template @@ -231,7 +238,7 @@ decltype(auto) get(tuple& t) noexcept { static_assert(I < sizeof...(T), "Index out of range"); - return detail::getv(t); + return eso::getv_r(t); } template @@ -240,22 +247,22 @@ decltype(auto) get(tuple&& t) noexcept { static_assert(I < sizeof...(T), "Index out of range"); - return detail::getv(static_cast&&>(t)); + return eso::getv_rr(static_cast&&>(t)); } -// Returns the position of type X (as a static integer) in the tuple -// type's argument list. X must be unique in the argument list. 
+// Returns the first position of type X (as a static integer) in the tuple +// type's argument list. template CUTE_HOST_DEVICE constexpr auto -find(tuple const& t) noexcept +find(tuple const&) noexcept { - return detail::findt(t); + return cute::C...>>{}; } // // Custom is_tuple trait simply checks the existence of tuple_size -// and assumes std::get(.), std::tuple_element +// and assumes get(.), tuple_element // namespace detail { @@ -269,19 +276,7 @@ template struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {}; template -constexpr bool is_tuple_v = cute::is_tuple::value; - -// -// make_tuple (value-based implementation) -// - -template -CUTE_HOST_DEVICE constexpr -tuple -make_tuple(T const&... t) -{ - return {t...}; -} +static constexpr bool is_tuple_v = cute::is_tuple::value; // // tuple_cat concatenates multiple cute::tuple into a single cute::tuple, diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp index b8ac5f0de5..dfffbe251f 100644 --- a/include/cute/container/type_list.hpp +++ b/include/cute/container/type_list.hpp @@ -31,6 +31,7 @@ #pragma once #include // CUTE_HOST_DEVICE, CUTE_STL_NAMESPACE +#include namespace cute { @@ -39,11 +40,35 @@ template struct type_list {}; // get for type_list -// requires tuple_element_t> to have std::is_default_constructible +// Get an instance of the Ith type in the pack T... +// Requires tuple_element_t> to have std::is_default_constructible template CUTE_HOST_DEVICE constexpr CUTE_STL_NAMESPACE::tuple_element_t> -get(type_list const& t) noexcept { +get(type_list const&) noexcept { + return {}; +} + +// Find the index of the first true in the pack B... +template +struct find_true { + CUTE_HOST_DEVICE static constexpr size_t find() { + size_t i = 0; + (void) ((B ? true : (++i, false)) || ...); + return i; + } + static constexpr size_t value = find(); +}; + +template +static constexpr size_t find_true_v = find_true::value; + +// find for type_list +// Finds the first position of type X (as a static integer) in the T... 
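The `find_true` helper added above counts leading `false` values with a short-circuiting logical-or fold. A standalone restatement of the same trick, including how it turns a first-matching-type search into a bool pack (the helper names here are illustrative, not the cute:: API):

```cpp
#include <cstddef>
#include <type_traits>

// Index of the first true in the pack, or sizeof...(B) if none is true.
template <bool... B>
constexpr std::size_t first_true() {
  std::size_t i = 0;
  // The fold short-circuits on the first true; each false bumps the counter.
  (void)((B ? true : (++i, false)) || ...);
  return i;
}

static_assert(first_true<false, false, true, true>() == 2, "");
static_assert(first_true<false, false>() == 2, "not found => pack size");

// First position of X in T..., by mapping the type pack onto a bool pack.
template <class X, class... T>
constexpr std::size_t find_first_v = first_true<std::is_same<X, T>::value...>();

static_assert(find_first_v<int, char, int, float> == 1, "");
```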
pack +template +CUTE_HOST_DEVICE constexpr +CUTE_STL_NAMESPACE::integral_constant...>> +find(type_list const&) noexcept { return {}; } @@ -69,9 +94,8 @@ struct tuple_size> template struct tuple_element> -{ - using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; -}; + : CUTE_STL_NAMESPACE::tuple_element> +{}; } // end namespace std @@ -94,9 +118,8 @@ struct tuple_size> template struct tuple_element> -{ - using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; -}; + : CUTE_STL_NAMESPACE::tuple_element> +{}; } // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 4ee901ada0..3f02a41d44 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -834,6 +834,8 @@ coalesce_x(Layout const& layout) } else { return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); } + + CUTE_GCC_UNREACHABLE; } // Apply coalesce_x at the terminals of trg_profile @@ -903,6 +905,8 @@ coalesce(Shape const& shape) } else { return append(init, a); // Can't coalesce, so append } + + CUTE_GCC_UNREACHABLE; }); } @@ -1026,7 +1030,7 @@ template CUTE_HOST_DEVICE constexpr auto -composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, +composition_impl(LShape const& lhs_shape, [[maybe_unused]] LStride const& lhs_stride, RShape const& rhs_shape, RStride const& rhs_stride) { if constexpr (is_tuple::value) { // Right-distributivity of Layout composition for RHS tuple @@ -1063,7 +1067,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, auto rest_stride = get<3>(init); auto curr_shape = get(lhs_shape); - auto curr_stride = get(lhs_stride); + [[maybe_unused]] auto curr_stride = get(lhs_stride); // Strong divisibility condition -- requires composition to be statically verifiable. 
//CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition"); @@ -1105,6 +1109,8 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, rest_shape / new_shape, next_stride); } + + CUTE_GCC_UNREACHABLE; }); if constexpr (tuple_size::value == 0) { @@ -1289,6 +1295,8 @@ right_inverse(Layout const& layout) } else { return init; } + + CUTE_GCC_UNREACHABLE; }); return coalesce(make_layout(result_shape, result_stride)); @@ -1344,9 +1352,11 @@ left_inverse(Layout const& layout) return make_tuple(append(result_shape, istride / size(result_shape)), append(result_stride, get(preprod_shape))); } + + CUTE_GCC_UNREACHABLE; }); - return coalesce(make_layout(append(result_shape, get(lshape)), + return coalesce(make_layout(append(result_shape, get(lshape)), result_stride)); } @@ -1499,7 +1509,7 @@ nullspace(Layout const& layout) { auto flat_layout = flatten(layout); - auto iseq = detail::nullspace_seq<0>(flat_layout.stride(), seq<>{}); + [[maybe_unused]] auto iseq = detail::nullspace_seq<0>(flat_layout.stride(), seq<>{}); if constexpr (iseq.size() == 0) { return Layout<_1,_0>{}; // Empty case, nothing found diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 33076378ea..60a4ff4abc 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -84,6 +84,8 @@ as_arithmetic_tuple(T const& t) { } else { return t; } + + CUTE_GCC_UNREACHABLE; } // diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index 83dcd4e6e5..147458b85d 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -57,7 +57,7 @@ template >(t) < static_cast>(u) ? t : u; } template ,Offset,Layout> const& layout) // Utilities // -namespace detail { - // Get just the Swizzle part of a composed layout. 
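Several utilities in this hunk gain a trailing `CUTE_GCC_UNREACHABLE;` after an exhaustive `if constexpr` chain. A hedged sketch of the pattern, assuming the macro expands to `__builtin_unreachable()` on GCC/Clang and to nothing elsewhere:

```cpp
#include <type_traits>

// Every branch of the chain returns, but some compilers still diagnose the
// syntactic fall-through path in templated code; marking it unreachable keeps
// the warning away without changing behavior.
template <class T>
constexpr auto classify(T) {
  if constexpr (std::is_integral_v<T>) {
    return 'i';
  } else if constexpr (std::is_floating_point_v<T>) {
    return 'f';
  } else {
    return '?';
  }
#if defined(__GNUC__)
  __builtin_unreachable();   // never executed; placates missing-return checks
#endif
}

static_assert(classify(1) == 'i', "");
static_assert(classify(1.0) == 'f', "");
```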
template CUTE_HOST_DEVICE constexpr @@ -167,8 +165,6 @@ get_nonswizzle_portion(Layout const& slayout) return slayout; } -} // namespace detail - // // Slice a Swizzled ComposedLayout // diff --git a/include/cute/tensor_impl.hpp b/include/cute/tensor_impl.hpp index 9c1a0b4420..0d9144884b 100644 --- a/include/cute/tensor_impl.hpp +++ b/include/cute/tensor_impl.hpp @@ -381,6 +381,8 @@ struct MakeTensor return Tensor(); } } + + CUTE_GCC_UNREACHABLE; } }; diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index 74ab834f49..c1032f0b0c 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -43,7 +43,7 @@ namespace cutlass { namespace arch { constexpr int sm100_smem_capacity_bytes = 232448; -constexpr int sm120_smem_capacity_bytes = 102400; +constexpr int sm120_smem_capacity_bytes = 101376; #if defined(__NVCC__) || defined(__CUDACC_RTC__) || (defined(__clang__) && (defined(__CUDA__) || defined(CUTLASS_ENABLE_SYCL))) diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index 249191371a..6280430d95 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -53,6 +53,9 @@ #define CUTLASS_ARCH_TCGEN_ENABLED 1 #endif +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED)) +#define CUTLASS_ARCH_TCGEN_ENABLED 1 +#endif namespace cutlass { /// @brief @@ -389,7 +392,7 @@ struct ClusterBarrier { // // Static Versions // - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static void init(ValueType const* smem_ptr, uint32_t arrive_count) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -406,7 +409,7 @@ struct ClusterBarrier { } // Static version of wait - in case we don't want to burn a register - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static void wait(ValueType const* smem_ptr, uint32_t phase) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -430,7 +433,7 @@ struct ClusterBarrier { #endif } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static bool test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -455,7 +458,7 @@ struct ClusterBarrier { return 0; } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static bool try_wait(ValueType const* smem_ptr, uint32_t phase) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -479,7 +482,7 @@ struct ClusterBarrier { } // Static Predicated version of the above - in case we know the address. 
- CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -501,7 +504,7 @@ struct ClusterBarrier { } // Barrier arrive on local smem - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static void arrive(ValueType const* smem_ptr) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -517,7 +520,7 @@ struct ClusterBarrier { #endif } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static void invalidate(ValueType const* smem_ptr) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -578,7 +581,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { // // Performs an arrive operation + expected transaction bytes increment - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static void arrive_and_expect_tx(ValueType const* smem_ptr, uint32_t transaction_bytes) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -595,7 +598,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { } // Performs an arrive operation + expected transaction bytes increment for a remote cta_id in a Cluster - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static void arrive_and_expect_tx( ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { #if CUDA_BARRIER_ENABLED @@ -616,7 +619,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { } // Performs an expected transaction bytes increment without doing an arrive operation - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static void expect_transaction(ValueType const* smem_ptr, uint32_t transaction_bytes) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -633,7 +636,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { } // Performs an expected transaction bytes decrement without doing an arrive operation - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static void complete_transaction( ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { #if CUDA_BARRIER_ENABLED @@ -728,7 +731,7 @@ void fence_view_async_shared() { } // Arrive on completion of in-flight cp.async operations issued by the calling thread -CUTLASS_DEVICE +CUTLASS_HOST_DEVICE void cpasync_barrier_arrive(uint64_t const* smem_ptr) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -745,7 +748,7 @@ void cpasync_barrier_arrive(uint64_t const* smem_ptr) { } // Arrive on completion of in-flight cp.async operations issued by the calling thread (noinc) -CUTLASS_DEVICE +CUTLASS_HOST_DEVICE void cpasync_barrier_arrive_noinc(uint64_t const* smem_ptr) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -764,7 +767,7 @@ void cpasync_barrier_arrive_noinc(uint64_t const* smem_ptr) { //////////////////////////////////////////////////////////////////////////////////////////////////// -CUTLASS_DEVICE +CUTLASS_HOST_DEVICE void umma_arrive(uint64_t const* smem_ptr) { #if defined(CUTLASS_ARCH_TCGEN_ENABLED) uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -779,7 +782,7 @@ void umma_arrive(uint64_t const* smem_ptr) { } //UMMA arrive for MMA_2x1SM -CUTLASS_DEVICE +CUTLASS_HOST_DEVICE void umma_arrive_2x1SM(uint64_t const* smem_ptr) { #if defined(CUTLASS_ARCH_TCGEN_ENABLED) uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -794,7 +797,7 @@ void umma_arrive_2x1SM(uint64_t const* 
smem_ptr) { } // UMMA arrive for MMA_1sm + TMA_LOAD_MULTICAST combination -CUTLASS_DEVICE +CUTLASS_HOST_DEVICE void umma_arrive_multicast(uint64_t const* smem_ptr, uint16_t cta_mask) { #if defined(CUTLASS_ARCH_TCGEN_ENABLED) uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -812,7 +815,7 @@ void umma_arrive_multicast(uint64_t const* smem_ptr, uint16_t cta_mask) { } // UMMA arrive for MMA_2x1SM + TMA_LOAD_MULTICAST combination -CUTLASS_DEVICE +CUTLASS_HOST_DEVICE void umma_arrive_multicast_2x1SM(uint64_t const* smem_ptr, uint16_t cta_mask) { #if defined(CUTLASS_ARCH_TCGEN_ENABLED) uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -824,14 +827,14 @@ void umma_arrive_multicast_2x1SM(uint64_t const* smem_ptr, uint16_t cta_mask) { : :"r"(bar_intptr), "h"(cta_mask)); } -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } // Temporary solution for sparse kernel. // Will remove this when we done tightly elect_one wrap. -CUTLASS_DEVICE +CUTLASS_HOST_DEVICE void umma_arrive_multicast_no_elect(uint64_t const* smem_ptr, uint16_t cta_mask) { #if defined(CUTLASS_ARCH_TCGEN_ENABLED) uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -850,7 +853,7 @@ void umma_arrive_multicast_no_elect(uint64_t const* smem_ptr, uint16_t cta_mask) // Temporary solution for sparse kernel. // UMMA arrive for MMA_2x1SM + TMA_LOAD_MULTICAST combination -CUTLASS_DEVICE +CUTLASS_HOST_DEVICE void umma_arrive_multicast_2x1SM_no_elect(uint64_t const* smem_ptr, uint16_t cta_mask) { #if defined(CUTLASS_ARCH_TCGEN_ENABLED) uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -868,7 +871,7 @@ void umma_arrive_multicast_2x1SM_no_elect(uint64_t const* smem_ptr, uint16_t cta } // Always arrive on even SM of collaborating 2 SMs. 
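This hunk widens many barrier helpers from `CUTLASS_DEVICE` to `CUTLASS_HOST_DEVICE` and, in the same functions, turns the fallback `brkpt` branch from `#else` into `#elif defined(__CUDA_ARCH__)`. A small CUDA sketch of why the two changes belong together (assumed rationale: once the function is also compiled for the host, the device-only inline assembly must be fenced off from the host pass):

```cpp
// Minimal sketch, not the CUTLASS function. Compile as a .cu file with nvcc.
__host__ __device__ inline void arrive_or_trap() {
#if defined(CUTLASS_ARCH_TCGEN_ENABLED)
  // real PTX arrive path elided in this sketch
#elif defined(__CUDA_ARCH__)
  asm volatile("brkpt;\n" ::);   // device-side breakpoint on unsupported arch
#else
  // host compilation pass: nothing to emit
#endif
}
```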
-CUTLASS_DEVICE +CUTLASS_HOST_DEVICE void umma_arrive_2x1SM_sm0(uint64_t const* smem_ptr) { #if defined(CUTLASS_ARCH_TCGEN_ENABLED) uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr) & cute::Sm100MmaPeerBitMask; @@ -879,7 +882,7 @@ void umma_arrive_2x1SM_sm0(uint64_t const* smem_ptr) { : : "r"(bar_intptr)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } diff --git a/include/cutlass/arch/config.h b/include/cutlass/arch/config.h index 1dd27f78db..e5daf8292b 100644 --- a/include/cutlass/arch/config.h +++ b/include/cutlass/arch/config.h @@ -92,6 +92,14 @@ #define CUTLASS_ARCH_MMA_SM100A_ENABLED 1 #endif + // SM100f + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM100F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) && CUDA_ARCH_FAMILY(1000)) + #define CUTLASS_ARCH_MMA_SM100F_ENABLED CUTLASS_ARCH_MMA_SM100F_SUPPORTED + #endif #endif #endif @@ -109,6 +117,14 @@ #define CUTLASS_ARCH_MMA_SM101A_ENABLED 1 #endif + // SM101f + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM101F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) && CUDA_ARCH_FAMILY(1010)) + #define CUTLASS_ARCH_MMA_SM101F_ENABLED CUTLASS_ARCH_MMA_SM101F_SUPPORTED + #endif #endif #endif @@ -124,12 +140,21 @@ #define CUTLASS_ARCH_MMA_SM120A_ENABLED 1 #endif + // SM120f + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM120F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) && CUDA_ARCH_FAMILY(1200)) + #define CUTLASS_ARCH_MMA_SM120F_ENABLED CUTLASS_ARCH_MMA_SM120F_SUPPORTED + #endif #endif #endif -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) # define CUTLASS_ARCH_CLC_ENABLED #endif diff --git a/include/cutlass/arch/grid_dependency_control.h b/include/cutlass/arch/grid_dependency_control.h index ae66de279d..e7defb5dbb 100644 --- a/include/cutlass/arch/grid_dependency_control.h +++ b/include/cutlass/arch/grid_dependency_control.h @@ -53,6 +53,20 @@ #endif #endif +#ifndef CUTLASS_GDC_ENABLED + #if(CUDA_BARRIER_ENABLED && \ + defined(CUTLASS_ENABLE_GDC_FOR_SM100) && \ + defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 1000 &&\ + (defined(__CUDA_ARCH_FEAT_SM100_ALL) || CUDA_ARCH_FAMILY(1000))) || \ + (__CUDA_ARCH__ == 1010 &&\ + (defined(__CUDA_ARCH_FEAT_SM101_ALL) || CUDA_ARCH_FAMILY(1010))) || \ + (__CUDA_ARCH__ == 1200 &&\ + (defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_FAMILY(1200))))) + #define CUTLASS_GDC_ENABLED + #endif +#endif + namespace cutlass { namespace arch { @@ -84,6 +98,5 @@ static constexpr bool IsGdcGloballyEnabled = true; static constexpr bool IsGdcGloballyEnabled = false; #endif - } // namespace arch } // namespace cutlass diff --git a/include/cutlass/arch/memory_sm75.h b/include/cutlass/arch/memory_sm75.h index 9192687763..040f707436 100644 --- a/include/cutlass/arch/memory_sm75.h +++ b/include/cutlass/arch/memory_sm75.h @@ -60,7 +60,7 @@ CUTLASS_DEVICE void ldsm(Array & D, void const* ptr); 
///////////////////////////////////////////////////////////////////////////////////////////////// /// CUTLASS helper to get SMEM pointer -CUTLASS_DEVICE unsigned cutlass_get_smem_pointer(void *ptr) { +CUTLASS_HOST_DEVICE unsigned cutlass_get_smem_pointer(void *ptr) { return cute::cast_smem_ptr_to_uint(ptr); } diff --git a/include/cutlass/arch/reg_reconfig.h b/include/cutlass/arch/reg_reconfig.h index 557643e5e6..a65ee3281f 100644 --- a/include/cutlass/arch/reg_reconfig.h +++ b/include/cutlass/arch/reg_reconfig.h @@ -47,6 +47,14 @@ #define CUDA_CTA_RECONFIG_ACTIVATED 1 #endif + #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ + (__CUDA_ARCH__ == 1000 && CUDA_ARCH_FAMILY(1000)) \ + || (__CUDA_ARCH__ == 1010 && CUDA_ARCH_FAMILY(1010)) \ + || (__CUDA_ARCH__ == 1200 && CUDA_ARCH_FAMILY(1200)) \ + ) + #define CUDA_CTA_RECONFIG_ACTIVATED 1 + #endif + #endif namespace cutlass { diff --git a/include/cutlass/arch/wmma.h b/include/cutlass/arch/wmma.h index 2cafa51085..9cb9c04f95 100644 --- a/include/cutlass/arch/wmma.h +++ b/include/cutlass/arch/wmma.h @@ -34,9 +34,6 @@ #pragma once -// CUTLASS WMMA does not support clang at present. -#if !(defined(__clang__) && defined(__CUDA__)) - #if (__CUDACC_VER_MAJOR__ >= 9) #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700)) #define CUTLASS_ARCH_WMMA_ENABLED @@ -58,8 +55,6 @@ #endif #endif -#endif //!(defined(__clang__) && defined(__CUDA__)) - #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include diff --git a/include/cutlass/array.h b/include/cutlass/array.h index e1e182827f..ce33110aa4 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -986,6 +986,21 @@ struct multiply_add, Array, Array> { return result; } + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, T const &scalar) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE Array operator()(Array const &a, T const &scalar_b, T const &scalar_c) const { diff --git a/include/cutlass/conv/collective/builders/sm100_umma_builder.inl b/include/cutlass/conv/collective/builders/sm100_umma_builder.inl index db1f7dae0a..9a9d4cb4e9 100644 --- a/include/cutlass/conv/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/conv/collective/builders/sm100_umma_builder.inl @@ -168,7 +168,7 @@ private: // Calculate SMEM matrix A and B buffers' pipeline stages static constexpr uint32_t AccumulatorPipelineStageCount = 2; - static constexpr uint32_t SchedulerPipelineStageCount = 2; + static constexpr uint32_t SchedulerPipelineStageCount = 1; static constexpr uint32_t CLCResponseSize = 16; // AccumulatorPipeline = PipelineUmmaAsync @@ -179,8 +179,6 @@ private: static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); // CLC (scheduler) response static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * CLCResponseSize; - // CLC Throttle pipeline storage - static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); // Tmem dealloc static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); // Tmem ptr storage @@ -190,7 +188,6 @@ private: CLCPipelineStorage + LoadOrderBarrierStorage + TmemDeallocStorage + - CLCThrottlePipelineStorage + CLCResponseStorage + TmemBasePtrsStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. 
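The builder above enumerates the barrier and bookkeeping storage it must carve out of shared memory (with the CLC-throttle entry now removed) before sizing the mainloop buffers. A hedged, plain-arithmetic sketch of that stage-count calculation; apart from the SM100 capacity constant, every byte count below is a made-up placeholder, not the builder's real value:

```cpp
#include <cstdio>

int main() {
  int smem_capacity_bytes = 232448;   // sm100_smem_capacity_bytes from arch.h
  int kernel_carveout     = 1536;     // pipelines + CLC responses + tmem ptrs (placeholder)
  int bytes_per_ab_stage  = 24576;    // one pipeline stage of A and B tiles (placeholder)

  int pipeline_stages = (smem_capacity_bytes - kernel_carveout) / bytes_per_ab_stage;
  std::printf("PipelineStages = %d\n", pipeline_stages);
  return 0;
}
```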
@@ -204,7 +201,12 @@ private: constexpr static int NumSpatialDimensions = detail::gmem_layout_tags_to_spatial_dims(); using DispatchPolicy = cutlass::conv::MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< - ConvOp, PipelineStages, NumSpatialDimensions, ClusterShape_MNK>; + ConvOp, + PipelineStages, + NumSpatialDimensions, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>; public: using CollectiveOp = cutlass::conv::collective::CollectiveConv< diff --git a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp index dc75b988d5..278f69f93f 100644 --- a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp @@ -28,9 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -// -// #pragma once @@ -66,6 +64,8 @@ template < conv::Operator ConvOp, int Stages, int NumSpatialDims, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShapeMNKL_, // (MmaAtomShapeM, MmaAtomShapeN, TileK, optional: TileL) class ElementA_, @@ -75,7 +75,12 @@ template < class TileTraitsB_> struct CollectiveConv< MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< - ConvOp, Stages, NumSpatialDims, ClusterShape>, + ConvOp, + Stages, + NumSpatialDims, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, TileShapeMNKL_, ElementA_, ElementB_, @@ -87,7 +92,12 @@ struct CollectiveConv< // Type Aliases // using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< - ConvOp, Stages, NumSpatialDims, ClusterShape>; + ConvOp, + Stages, + NumSpatialDims, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; using TileShape = decltype(cute::take<0,3>(TileShapeMNKL_{})); // (MmaAtomShapeM, MmaAtomShapeN, TileK) using ElementA = ElementA_; using ElementB = ElementB_; @@ -348,10 +358,12 @@ struct CollectiveConv< // Constructor // CUTLASS_DEVICE - CollectiveConv(Params const& params) { + CollectiveConv(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { if constexpr (IsDynamicCluster) { - dim3 cs = cute::cluster_shape(); - const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; observed_tma_load_b_ = is_fallback_cluster ? 
¶ms.tma_load_b_fallback : ¶ms.tma_load_b; } @@ -648,28 +660,14 @@ struct CollectiveConv< } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE static void - prefetch_tma_descriptors(Params const& mainloop_params) { - if constexpr (IsDynamicCluster) { - dim3 cs = cute::cluster_shape(); - const bool is_fallback_cluster = (cs.x == mainloop_params.cluster_shape_fallback.x && cs.y == mainloop_params.cluster_shape_fallback.y); - if (is_fallback_cluster) { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a_fallback.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b_fallback.get_tma_descriptor()); - } - else { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - } - } - else { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - } + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); } /// Construct A Single Stage's Accumulator Shape - CUTLASS_DEVICE auto + CUTLASS_DEVICE static auto partition_accumulator_shape() { auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) @@ -794,11 +792,10 @@ struct CollectiveConv< Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); - Layout cta_layout_mnk = make_layout(cluster_shape); + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); - int block_rank_in_cluster = cute::block_rank_in_cluster(); - auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); // Project the cta_layout for tma_a along the n-modes auto [tAgA_mk, tAsA] = tma_partition(*observed_tma_load_a_, @@ -890,7 +887,7 @@ struct CollectiveConv< } CUTLASS_DEVICE auto - mma_init(TensorStorage& shared_tensors) { + mma_init(TensorStorage& shared_tensors) const { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -909,6 +906,9 @@ struct CollectiveConv< typename Params::TMA_A const* observed_tma_load_a_ = nullptr; typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/dispatch_policy.hpp b/include/cutlass/conv/dispatch_policy.hpp index b4bf8a5382..d569cb1c3e 100644 --- a/include/cutlass/conv/dispatch_policy.hpp +++ b/include/cutlass/conv/dispatch_policy.hpp @@ -86,7 +86,10 @@ struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm { // SM100 tensor op kernel schedule -struct KernelImplicitTmaWarpSpecializedSm100 
{ }; +struct KernelImplicitTmaWarpSpecializedSm100 { + static constexpr int SchedulerPipelineStageCount = 0; + static constexpr int AccumulatorPipelineStageCount = 0; +}; // Pseudo-policies for builder auto override that dispatches to the KernelImplicitTmaWarpSpecializedSm100 // but for opting into 1 or 2 SM atoms @@ -96,11 +99,23 @@ struct KernelImplicitTmaWarpSpecialized2SmSm100 : KernelImplicitTmaWarpSpecializ struct KernelStridedDgradTmaWs1SmSm100 { }; struct KernelStridedDgradTmaWs2SmSm100 { }; +// Policy for implicit gemm kernel +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelScheduleImplicitTmaWarpSpecializedSm100 : KernelImplicitTmaWarpSpecializedSm100 { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + // n-buffer in smem (Blackwell TMA), pipelined with Blackwell UMMA and TMA, fprop template< conv::Operator ConvOp_, int Stages_, int NumSpatialDimensions_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, class ClusterShape_ = cute::Shape,cute::C<1>,cute::C<1>> > struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm { @@ -109,7 +124,7 @@ struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm { static constexpr Operator ConvOp = ConvOp_; using ClusterShape = ClusterShape_; using ArchTag = arch::Sm100; - using Schedule = KernelImplicitTmaWarpSpecializedSm100; + using Schedule = KernelScheduleImplicitTmaWarpSpecializedSm100; static_assert(NumSpatialDimensions >= 1); }; diff --git a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp index 90236e1fd9..0874d8f8ab 100644 --- a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp @@ -29,8 +29,6 @@ * **************************************************************************************************/ - - #pragma once #include "cutlass/cutlass.h" @@ -110,7 +108,8 @@ class ConvUniversal< static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; // TileID scheduler // CLC pipeline depth determines how many waves (stages-1) the scheduler can race ahead - static constexpr uint32_t SchedulerPipelineStageCount = 2; + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; using TileSchedulerTag = TileSchedulerTag_; using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelector< @@ -135,7 +134,6 @@ class ConvUniversal< static constexpr uint32_t NumFixupBarriers = 1; // Pipelines and pipeline states - static constexpr uint32_t AccumulatorPipelineStageCount = SchedulerPipelineStageCount; static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); // Pipeline and pipeline state types @@ -157,10 +155,6 @@ class ConvUniversal< using CLCPipelineState = cutlass::PipelineDetail::PipelineCLCFetchAsyncPipelineState; using CLCPipelineSharedStorage = cutlass::PipelineDetail::PipelineCLCFetchAsyncSharedStorage; - using CLCThrottlePipeline = cutlass::PipelineAsync; - using CLCThrottlePipelineState = cutlass::PipelineDetail::PipelineAsyncPipelineState; - using CLCThrottlePipelineSharedStorage = 
cutlass::PipelineDetail::PipelineAsyncSharedStorage; - using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; @@ -172,14 +166,12 @@ class ConvUniversal< using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; using CLCPipelineStorage = CLCPipelineSharedStorage; using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; - using CLCThrottlePipelineStorage = CLCThrottlePipelineSharedStorage; alignas(16) MainloopPipelineStorage mainloop; alignas(16) EpiLoadPipelineStorage epi_load; alignas(16) LoadOrderBarrierStorage load_order; alignas(16) CLCPipelineStorage clc; alignas(16) AccumulatorPipelineStorage accumulator; - alignas(16) CLCThrottlePipelineStorage clc_throttle; alignas(16) arch::ClusterBarrier tmem_dealloc; } pipelines; @@ -193,7 +185,6 @@ class ConvUniversal< EpilogueTensorStorage epilogue; MainloopTensorStorage mainloop; } tensors; - }; static constexpr int SharedStorageSize = sizeof(SharedStorage); @@ -207,7 +198,7 @@ class ConvUniversal< KernelHardwareInfo hw_info{}; TileSchedulerArguments scheduler{}; }; - + // Kernel device entry point API struct Params { using ProblemShapeMNKL = decltype(CollectiveMainloop::get_problem_shape_MNKL(ProblemShape{})); @@ -398,7 +389,7 @@ class ConvUniversal< : WarpCategory::Epilogue; uint32_t lane_predicate = cute::elect_one_sync(); - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}); int cluster_size = size(cluster_shape); uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; @@ -407,24 +398,23 @@ class ConvUniversal< constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? 
cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_category == WarpCategory::Sched) && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - } - if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop(params.mainloop); + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + collective_mainloop.prefetch_tma_descriptors(); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + collective_epilogue.prefetch_tma_descriptors(params.epilogue); + } + // Do we load source tensor C or other aux inputs bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); - IsParticipant is_participant = { (warp_category == WarpCategory::MMA), // mma (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched @@ -462,7 +452,7 @@ class ConvUniversal< epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; - epi_load_pipeline_params.initializing_warp = 4; + epi_load_pipeline_params.initializing_warp = 1; EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); // Epilogue Store pipeline @@ -474,7 +464,7 @@ class ConvUniversal< typename LoadOrderBarrier::Params load_order_barrier_params; load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; load_order_barrier_params.group_size = NumMainloopLoadThreads; - load_order_barrier_params.initializing_warp = 5; + load_order_barrier_params.initializing_warp = 3; LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); // CLC pipeline @@ -493,7 +483,7 @@ class ConvUniversal< clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; } clc_pipeline_params.transaction_bytes = CLCResponseSize; - clc_pipeline_params.initializing_warp = 1; + clc_pipeline_params.initializing_warp = 4; CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); // Mainloop-Epilogue pipeline @@ -507,29 +497,13 @@ class ConvUniversal< // Only one producer thread arrives on this barrier. 
accumulator_pipeline_params.producer_arv_count = 1; accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; - accumulator_pipeline_params.initializing_warp = 2; + accumulator_pipeline_params.initializing_warp = 5; AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, accumulator_pipeline_params, cluster_shape, cute::true_type{}, // Perform barrier init cute::false_type{}); // Delay mask calculation - // CLC throttle pipeline - typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; - if (WarpCategory::MainloopLoad == warp_category) { - clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; - } - if (WarpCategory::Sched == warp_category) { - clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; - } - clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; - clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; - clc_throttle_pipeline_params.dst_blockid = 0; - clc_throttle_pipeline_params.initializing_warp = 3; - CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); - CLCThrottlePipelineState clc_pipe_throttle_consumer_state; - CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); - // Tmem allocator TmemAllocator tmem_allocator{}; @@ -544,12 +518,10 @@ class ConvUniversal< // We need this to guarantee that the Pipeline init is visible // To all producers and consumer threadblocks in the cluster - if (cluster_size > 1) { - cute::cluster_arrive_relaxed(); - } - else { - __syncthreads(); - } + pipeline_init_arrive_relaxed(cluster_size); + + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); uint32_t tmem_stage_ptrs[AccumulatorPipelineStageCount]; MainloopPipelineState mainloop_pipe_consumer_state; @@ -571,7 +543,7 @@ class ConvUniversal< // Calculate mask after cluster barrier arrival mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); - accumulator_pipeline.init_masks(cluster_shape); + accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); // TileID scheduler TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, problem_shape_MNKL, TileShape{}, block_id_in_cluster); @@ -583,58 +555,13 @@ class ConvUniversal< int TmemColumnsPerAccumulatorTile = cutlass::detail::find_tmem_tensor_col_offset(accumulators); pipeline_init_wait(cluster_size); - if (is_participant.sched) { - - // Whether a new CLC query must be performed. - // See comment below where this variable is updated for a description of - // why this variable is needed. - bool requires_clc_query = true; - - do { - if (requires_clc_query) { - // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. - clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); - clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); - ++clc_pipe_throttle_consumer_state; - - // Query next clcID and update producer state - clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); - } - - // Fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); - - // Only perform a new CLC query if we consumed a new CLC query result in - // `fetch_next_work`. 
An example of a case in which CLC `fetch_next_work` does - // not consume a new CLC query response is when processing stream-K units. - // The current stream-K scheduler uses single WorkTileInfo to track multiple - // (potentially-partial) tiles to be computed via stream-K. In this case, - // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, - // rather than consuming a CLC query response. - requires_clc_query = increment_pipe; - if (increment_pipe) { - ++clc_pipe_consumer_state; - } - - work_tile_info = next_work_tile_info; - } while (work_tile_info.is_valid()); - clc_pipeline.producer_tail(clc_pipe_producer_state); - } - else if (is_participant.main_load) { - + if (is_participant.main_load) { // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction cutlass::arch::wait_on_dependent_grids(); bool do_load_order_arrive = is_epi_load_needed; - auto load_inputs = collective_mainloop.load_init( - problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); Tensor gA_mk = get<0>(load_inputs); - bool requires_clc_query = true; do { // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. @@ -642,12 +569,6 @@ class ConvUniversal< auto k_tile_count = scheduler.get_work_k_tile_count(work_tile_info, problem_shape_MNKL, TileShape{}); auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); - if (is_first_cta_in_cluster && requires_clc_query) { - clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); - clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); - ++clc_pipe_throttle_producer_state; - } - auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load( params.mainloop, mainloop_pipeline, @@ -683,7 +604,6 @@ class ConvUniversal< ); work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); - requires_clc_query = increment_pipe; if (increment_pipe) { ++clc_pipe_consumer_state; } @@ -691,60 +611,43 @@ class ConvUniversal< collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); } - else if (is_participant.epi_load) { - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); + else if (is_participant.sched) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; - bool do_load_order_wait = true; - bool do_tail_load = false; do { - bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + if (requires_clc_query) { + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } - // Get current work tile and fetch next work tile + // Fetch next work tile auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( work_tile_info, clc_pipeline, clc_pipe_consumer_state ); - work_tile_info = next_work_tile_info; + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. 
+ // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; if (increment_pipe) { ++clc_pipe_consumer_state; } - if (compute_epilogue) { - - if (do_load_order_wait) { - load_order_barrier.wait(); - do_load_order_wait = false; - } - - epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - CtaShape_MNK{}, - cta_coord_mnkl, - TileShape{}, - TiledMma{}, - shared_storage.tensors.epilogue - ); - - do_tail_load = true; - } - - // Calculate the cta coordinates of the next work tile - cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + work_tile_info = next_work_tile_info; } while (work_tile_info.is_valid()); - - if (do_tail_load) { - collective_epilogue.load_tail( - epi_load_pipeline, epi_load_pipe_producer_state, - epi_store_pipeline, epi_store_pipe_producer_state); - } + clc_pipeline.producer_tail(clc_pipe_producer_state); } + else if (is_participant.mma) { // Tmem allocation sequence tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); @@ -757,6 +660,7 @@ class ConvUniversal< tmem_stage_ptrs[acc_stage] = tmem_base_ptr + (TmemColumnsPerAccumulatorTile * acc_stage) & cutlass::detail::TmemColMask; } auto mma_inputs = collective_mainloop.mma_init(shared_storage.tensors.mainloop); + do { auto k_tile_count = scheduler.get_work_k_tile_count(work_tile_info, problem_shape_MNKL, TileShape{}); @@ -788,7 +692,6 @@ class ConvUniversal< mma_inputs, k_tile_count ); - accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); } ++accumulator_pipe_producer_state; @@ -802,6 +705,7 @@ class ConvUniversal< // Release the right to allocate before deallocations so that the next CTA can rasterize tmem_allocator.release_allocation_lock(); + // Leader MMA waits for leader + peer epilogues to release accumulator stage if (is_mma_leader_cta) { accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); @@ -816,8 +720,66 @@ class ConvUniversal< // Free entire tmem allocation tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + 
} while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } } + else if (is_participant.epilogue) { // Wait for tmem allocate here tmem_allocation_result_barrier.arrive_and_wait(); @@ -875,13 +837,16 @@ class ConvUniversal< epi_load_pipe_consumer_state = load_state_next; epi_store_pipe_producer_state = store_state_next; accumulator_pipe_consumer_state = acc_state_next; - do_tail_store = true; } work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); } while (work_tile_info.is_valid()); + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). if (do_tail_store) { collective_epilogue.store_tail( epi_load_pipeline, epi_load_pipe_consumer_state, @@ -889,19 +854,8 @@ class ConvUniversal< CtaShape_MNK{}); } } - } - -private: - // Synchronization call. Blocks until barriers are initialized in shared memory. - CUTLASS_DEVICE - void - pipeline_init_wait(int cluster_size) { - if (cluster_size > 1) { - cute::cluster_wait(); - } else { - __syncthreads(); } } }; diff --git a/include/cutlass/detail/sm100_blockwise_scale_layout.hpp b/include/cutlass/detail/blockwise_scale_layout.hpp similarity index 67% rename from include/cutlass/detail/sm100_blockwise_scale_layout.hpp rename to include/cutlass/detail/blockwise_scale_layout.hpp index 8f75bd2561..2d545bbd1e 100644 --- a/include/cutlass/detail/sm100_blockwise_scale_layout.hpp +++ b/include/cutlass/detail/blockwise_scale_layout.hpp @@ -179,11 +179,110 @@ struct Sm100BlockwiseScaleConfig { }; +template +struct RuntimeBlockwiseScaleConfig { + + using ShapeSFA = Shape, Shape, int32_t>; + using ShapeSFB = Shape, Shape, int32_t>; + + using StrideSFA = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using StrideSFB = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutSFA = Layout; + using LayoutSFB = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFA() { + return LayoutSFA{}; + } + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFB() { + return LayoutSFB{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_SFA. 
+ template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFA(ProblemShape problem_shape, SFVecShape sf_vec_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(M, sfm))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + auto mk_layout = make_layout( + make_shape(make_shape(sfm, cute::ceil_div(M, sfm)), + make_shape(sfk, cute::ceil_div(K, sfk))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFB(ProblemShape problem_shape, SFVecShape sf_vec_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + + if constexpr (majorSFB == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(N, sfn))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + auto nk_layout = make_layout( + make_shape(make_shape(sfn, cute::ceil_div(N, sfn)), + make_shape(sfk, cute::ceil_div(K, sfk))), + strides + ); + + return make_layout(append(shape(nk_layout), L), append(stride(nk_layout), size(filter_zeros(nk_layout)))); + } + +}; + +// Sm90 only supports MN major for SFA and SFB for now +template +using Sm90BlockwiseScaleConfig = Sm100BlockwiseScaleConfig; + template constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) { return Sm100BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; } +template +constexpr auto sm90_trivial_blockwise_scale_config(MmaTileShape_MNK) { + return Sm90BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::detail diff --git a/include/cutlass/detail/helper_macros.hpp b/include/cutlass/detail/helper_macros.hpp index 758b52d3a0..94634e950f 100644 --- a/include/cutlass/detail/helper_macros.hpp +++ b/include/cutlass/detail/helper_macros.hpp @@ -217,6 +217,35 @@ namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// +// __CUDA_ARCH_SPECIFIC__ is introduced in CUDA 12.9 +#if !defined(CUDA_ARCH_CONDITIONAL) + +#if defined(__CUDA_ARCH_SPECIFIC__) +#define CUDA_ARCH_CONDITIONAL(ARCH_XXYY) (__CUDA_ARCH_SPECIFIC__ == ARCH_XXYY) +#else +#define CUDA_ARCH_CONDITIONAL(ARCH_XXYY) (false) +#endif + +#endif + +// __CUDA_ARCH_FAMILY_SPECIFIC__ is introduced in CUDA 12.9 +#if !defined(CUDA_ARCH_FAMILY) + +#if defined(__CUDA_ARCH_FAMILY_SPECIFIC__) +#define CUDA_ARCH_FAMILY(ARCH_XXYY) (__CUDA_ARCH_FAMILY_SPECIFIC__ == ARCH_XXYY) +#else +#define CUDA_ARCH_FAMILY(ARCH_XXYY) (false) +#endif + +#endif + 
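For the runtime blockwise-scale config above, the scale-factor tensor extents follow directly from the problem shape and the per-block granularity: one scale covers an (sfm x sfk) block of A, or an (sfn x sfk) block of B. A plain-arithmetic sketch of that shape math; the sizes below are hypothetical examples, not defaults from the library:

```cpp
#include <cstdint>
#include <cstdio>

constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  int64_t M = 4096, N = 8192, K = 7168, L = 2;   // hypothetical GEMM extents
  int64_t sfm = 128, sfn = 128, sfk = 128;       // hypothetical scale granularity

  int64_t sfa_per_batch = ceil_div(M, sfm) * ceil_div(K, sfk);  // scales for A
  int64_t sfb_per_batch = ceil_div(N, sfn) * ceil_div(K, sfk);  // scales for B

  std::printf("SFA: %lld per batch (%lld total)\n",
              (long long)sfa_per_batch, (long long)(sfa_per_batch * L));
  std::printf("SFB: %lld per batch (%lld total)\n",
              (long long)sfb_per_batch, (long long)(sfb_per_batch * L));
  return 0;
}
```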
+#if !defined(CUDA_ARCH_CONDITIONAL_OR_FAMILY) +#define CUDA_ARCH_CONDITIONAL_OR_FAMILY(ARCH_XXYY) \ + (CUDA_ARCH_CONDITIONAL(ARCH_XXYY) || CUDA_ARCH_FAMILY(ARCH_XXYY)) +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + }; // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index a0a183b0ee..562adc65ea 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -33,10 +33,10 @@ #include "cute/layout.hpp" #include "cute/pointer_sparse.hpp" // cute::is_sparse #include "cute/swizzle.hpp" // cute::Swizzle -#include "cute/swizzle_layout.hpp" // cute::detail::get_swizzle_portion +#include "cute/swizzle_layout.hpp" // cute::get_swizzle_portion #include "cute/util/type_traits.hpp" #include "cute/arch/copy_sm90_tma.hpp" -#include "cute/arch/copy_sm100_tma.hpp" +#include "cute/arch/copy_sm100_tma.hpp" #include "cutlass/layout/matrix.h" #include "cutlass/layout/tensor.h" @@ -219,8 +219,8 @@ stride_to_layout_tag_A() { return layout::ColumnMajor{}; } // Specialize for sparse layout - else if constexpr (cute::get<0>(InternalStrideA{}) == cute::_2{} && - cute::rank(cute::get<1>(InternalStrideA{})) == 2 && + else if constexpr (cute::get<0>(InternalStrideA{}) == cute::_2{} && + cute::rank(cute::get<1>(InternalStrideA{})) == 2 && cute::is_same_v(InternalStrideA{}))>>) { return layout::ColumnMajor{}; } @@ -308,8 +308,8 @@ constexpr bool is_tma_copy_engine() { || cute::is_base_of_v || cute::is_base_of_v || cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v ) { return true; } @@ -349,7 +349,7 @@ get_alignment_count_from_gmem_tiled_copy() { cutlass::gemm::collective::detail::is_sm10x_f8f6f4_element() && cute::is_same_v::type, uint8_t>) { return 128; } - + // For sparse MMA, alignment in logical elements is increased by sparsity factor if constexpr (cute::is_sparse_v) { return 128 / sizeof_bits::value * ElementMma::sparsity; @@ -366,7 +366,7 @@ get_alignment_count_from_gmem_tiled_copy() { // Return alignment bit requirements for the GEMM inputs. template < class ElementType - , bool IsF8F6F4SubBytes=false + , bool IsF8F6F4SubBytes=false > constexpr int get_input_alignment_bits() { @@ -383,12 +383,12 @@ get_input_alignment_bits() { template constexpr int get_output_alignment_bits() { - + if constexpr (sizeof_bits::value == 6) { // U6 format : The inner tensor size dimension must be a multiple of 96B. 
return 96 * 8; } - + return 128; } @@ -424,7 +424,7 @@ template CUTLASS_HOST_DEVICE constexpr size_t alignment_for_swizzle(Layout layout) { - return alignment_for_swizzle(cute::detail::get_swizzle_portion(layout)); + return alignment_for_swizzle(cute::get_swizzle_portion(layout)); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/include/cutlass/epilogue/collective/builders/sm100_builder.inl index 16eb4fc9f4..176b1f257f 100644 --- a/include/cutlass/epilogue/collective/builders/sm100_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -866,6 +866,45 @@ struct CallbacksBuilder< >; }; +// ptr array aux fusion callbacks builder for sm100 tma epilogue +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class CtaTileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class AccLoadOp +> +struct CallbacksBuilder< + Sm100PtrArrayTmaWarpSpecialized, + FusionOp, + CtaTileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not cute::is_subbyte_v> +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using CopyOpR2S = decltype(detail::sm100_get_smem_store_op< + GmemStrideTypeAux, typename FusionOp::ElementAux, ElementAccumulator, AccLoadOp>()); + using CopyOpS2R = decltype(detail::sm100_get_smem_load_op< + GmemStrideTypeAux, typename FusionOp::ElementAux, ElementAccumulator, AccLoadOp>()); + using SmemCopyOpAux = cute::conditional_t; + + using Callbacks = fusion::FusionCallbacks< + Sm100PtrArrayTmaWarpSpecialized, + FusionOp, CtaTileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + template < int StagesC, int StagesD, @@ -930,7 +969,7 @@ template < class ElementC_, class GmemLayoutTagC_, int AlignmentC, - class ElementD, + class ElementD_, class GmemLayoutTagD, int AlignmentD, class Schedule, @@ -943,6 +982,9 @@ private: static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule"); static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch"); + static constexpr bool DisableDestination = cute::is_void_v; + using ElementD = cute::conditional_t,ElementD_>; // prevents void ref breakages + // Passing void C disables source load + smem allocation static constexpr bool DisableSource = cute::is_void_v; using ElementC = cute::conditional_t; // prevents void ref breakages @@ -1168,7 +1210,7 @@ public: EpilogueTile_MN, ElementC_, // Need to pass void through to expose via GemmUniversal GmemStrideTypeC, - ElementD, + ElementD_, // Need to pass void through to expose via GemmUniversal GmemStrideTypeD, decltype(fusion_callbacks()), AccLoadOp, diff --git a/include/cutlass/epilogue/collective/builders/sm120_builder.inl b/include/cutlass/epilogue/collective/builders/sm120_builder.inl index ad1f44a062..e1c1bff803 100644 --- a/include/cutlass/epilogue/collective/builders/sm120_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm120_builder.inl @@ -63,13 +63,27 @@ struct EpilogueSFVecSize> static constexpr int value = FusionOp::SFVecSize; }; +// Helper to deduce NumEpilogueWarpGroups based on Schedule +template +struct 
GetNumEpilogueWarpGroups { + static constexpr int value = 2; +}; + +template +struct GetNumEpilogueWarpGroups> { + static constexpr int value = Schedule::NumEpilogueWarpGroups; +}; + // Returns the parameterized dispatch policy for the TMA epilogue -template +template constexpr auto sm120_get_tma_dispatch_policy() { using namespace cute; constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{})); + using StrideD = cutlass::detail::TagToStrideC_t; + using InternalStrideD = cute::remove_pointer_t; + constexpr bool IsGroupedGemmKernel = !cute::is_same_v; // For 120, a FragmentSize of 4 is used to match the // output per thread from each MMA. Epilogue subtiles iterate over multiple of these @@ -86,9 +100,17 @@ sm120_get_tma_dispatch_policy() { // SM120 epilogues use smaller stage counts in order to fit within the limited shared memory capacity. constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 2), StagesD+1) - : StagesD; - - return Sm120TmaWarpSpecialized{}; + : StagesD; + + constexpr int NumEpilogueWarpGroups = GetNumEpilogueWarpGroups::value; + + if constexpr (IsGroupedGemmKernel) { + return Sm120PtrArrayTmaWarpSpecialized{}; + } + else { + return Sm120TmaWarpSpecialized{}; + } } // Returns the smem layout atom to be used for C or D matrix @@ -291,6 +313,9 @@ struct Sm120TmaBuilderImpl { using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + using UnderlyingGmemStrideTypeC = cute::remove_pointer_t; + using UnderlyingGmemStrideTypeD = cute::remove_pointer_t; + using CopyOpS2G = cute::conditional_t, SM90_TMA_STORE_IM2COL, @@ -306,15 +331,15 @@ struct Sm120TmaBuilderImpl { // Get the smallest tiled copy we can use to retile the accumulators using CopyAtomC = Copy_Atom; - using SmemLayoutAtomC = decltype(detail::sm120_get_epilogue_smem_swizzle_layout_atom()); - using SmemLayoutAtomD = decltype(detail::sm120_get_epilogue_smem_swizzle_layout_atom()); + using SmemLayoutAtomC = decltype(detail::sm120_get_epilogue_smem_swizzle_layout_atom()); + using SmemLayoutAtomD = decltype(detail::sm120_get_epilogue_smem_swizzle_layout_atom()); - using CopyOpS2R = decltype(detail::sm120_get_smem_load_op_for_source()); + using CopyOpS2R = decltype(detail::sm120_get_smem_load_op_for_source()); - using CopyOpR2S = decltype(detail::sm120_get_smem_store_op_for_accumulator()); + using CopyOpR2S = decltype(detail::sm120_get_smem_store_op_for_accumulator()); // Get register to register tiled copy that happen before shared memory store. - using CopyOpR2R = decltype(detail::sm120_get_register_transform_op()); + using CopyOpR2R = decltype(detail::sm120_get_register_transform_op()); // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks // instance or a direct visitor implementation, e.g. 
fusion::Sm90LinearCombination @@ -334,8 +359,32 @@ struct Sm120TmaBuilderImpl { constexpr static bool ReuseSmemC = DispatchPolicy::ReuseSmemC; constexpr static bool DelayTmaStore = DispatchPolicy::DelayTmaStore; + //Helper to deduce BaseDispatchPolicy based on DispatchPolicy + template + struct GetBaseDispatchPolicy { + using Type = T; + }; + + template + struct GetBaseDispatchPolicy> { + using Type = typename cutlass::epilogue::Sm90PtrArrayTmaWarpSpecialized; + }; + + template + struct GetBaseDispatchPolicy> { + using Type = typename cutlass::epilogue::Sm90TmaWarpSpecialized; + }; + + using BaseDispatchPolicy = typename GetBaseDispatchPolicy::Type; + using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< - Sm90TmaWarpSpecialized, + BaseDispatchPolicy, TileShape_MNK, EpilogueTile_MN, ElementC_, // Need to pass void through to expose via GemmUniversal @@ -394,13 +443,15 @@ struct CollectiveBuilder< cute::enable_if_t || cute::is_same_v || cute::is_same_v || + cute::is_same_v || + cute::is_same_v || cute::is_same_v >> { private: using EpilogueTile_MN = decltype(detail::sm120_compute_tile_shape_or_override, FusionOperation>()); using DispatchPolicy = - decltype(detail::sm120_get_tma_dispatch_policy, Schedule>()); + decltype(detail::sm120_get_tma_dispatch_policy()); public: diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index f684437580..9cb03fdc21 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -116,13 +116,13 @@ sm90_compute_tile_shape_or_override() { auto epi_tile = [&] () { if constexpr (detail::sm90_is_cooperative_v) { auto tile_m = cute::min(_128{}, size<0>(TileShape_MNK{})); - auto tile_n = cute::min(_32{}, size<1>(TileShape_MNK{})); + auto tile_n = cute::gcd(cute::min(_32{}, size<1>(TileShape_MNK{})), size<1>(TileShape_MNK{})); return make_shape(tile_m, tile_n); } else if constexpr (detail::sm90_is_warp_specialized_v) { constexpr int N_perf = sizeof_bits_v == 8 ? 
64 : 32; auto tile_m = cute::min(_64{}, size<0>(TileShape_MNK{})); - auto tile_n = cute::min(Int{}, size<1>(TileShape_MNK{})); + auto tile_n = cute::gcd(cute::min(Int{}, size<1>(TileShape_MNK{})), size<1>(TileShape_MNK{})); return make_shape(tile_m, tile_n); } else { @@ -206,6 +206,46 @@ struct CallbacksBuilder< >; }; +// ptr array aux fusion callbacks builder for sm90 tma epilogue +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class AccLoadOp, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm90PtrArrayTmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not cute::is_subbyte_v> // aux subbyte tensor doesn't use smem +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using CopyOpR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using CopyOpS2R = decltype(detail::sm90_get_smem_load_op_for_source< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using SmemCopyOpAux = cute::conditional_t; + + using Callbacks = fusion::FusionCallbacks< + Sm90PtrArrayTmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + template < int StagesC, int StagesD, diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index b7bd6f4077..0d019b1c8c 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -35,6 +35,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/epilogue/collective/detail.hpp" @@ -225,22 +226,27 @@ class DefaultEpilogue { return; } + using FragCType = remove_cvref_t; + using FragDType = remove_cvref_t; + // source is needed if (epilogue_op.is_source_needed()) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), residue_tCcD)) { - tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); - } + FragCType fragC; + bool pred = elem_less(tCcD(i), residue_tCcD); + arch::global_load(fragC, &tCgC(i), pred); + FragDType fragD = epilogue_op(accumulators(i), fragC); + arch::global_store(fragD, &tCgD(i), pred); } } // source is not needed, avoid load else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), residue_tCcD)) { - tCgD(i) = epilogue_op(accumulators(i)); - } + bool pred = elem_less(tCcD(i), residue_tCcD); + FragDType fragD = epilogue_op(accumulators(i)); + arch::global_store(fragD, &tCgD(i), pred); } } } diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index 2759d0c638..2c72c30168 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -124,6 +124,23 @@ struct sm90_is_ptr_array_tma_dispatch_policy< NumEpilogueWarpGroups>> : cute::true_type {}; +template< + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int 
NumEpilogueWarpGroups +> +struct sm90_is_ptr_array_tma_dispatch_policy< + Sm120PtrArrayTmaWarpSpecialized> + : cute::true_type {}; + template static constexpr bool sm90_is_ptr_array_tma_dispatch_policy_v = sm90_is_ptr_array_tma_dispatch_policy::value; diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp index 9c24913e9a..b9fb5320c1 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp @@ -129,8 +129,13 @@ class CollectiveEpilogue< static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); private: - using GmemElementD = ElementD; - using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using GmemElementD = cute::conditional_t>; + using GmemElementC = cute::conditional_t; // prevents void ref breakages + static_assert(not cute::is_void_v, "GmemElementD is void"); + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; constexpr static int StagesC = StagesC_; @@ -138,9 +143,7 @@ class CollectiveEpilogue< static_assert(StagesC >= 1, "StagesC must be >= 1"); static_assert(StagesD >= 1, "StagesD must be >= 1"); - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; - constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool ReuseSmemC = ReuseSmemC_ && is_destination_supported; constexpr static bool is_m_major_C = detail::is_m_major(); constexpr static bool is_m_major_D = detail::is_m_major(); @@ -159,7 +162,7 @@ class CollectiveEpilogue< using SmemLayoutC = decltype(cute::append<3>(SmemLayoutStageC{}, Layout, Int>{})); using SmemLayoutD = decltype(cute::append<3>(SmemLayoutStageD{}, Layout, Int>{})); - constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC && MaxStageBits % sizeof_bits_v == 0 && MaxStageBits % sizeof_bits_v == 0; static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); @@ -168,6 +171,12 @@ class CollectiveEpilogue< constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. 
+ constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + // TMA store delay only benefits with loop unrolling + constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; + struct CollectiveStorageWithC { alignas(SmemAlignmentC) ArrayEngine> smem_C; alignas(SmemAlignmentD) ArrayEngine> smem_D; @@ -239,7 +248,7 @@ class CollectiveEpilogue< using TMA_C = decltype(make_tma_copy( CopyOpG2S{}, make_tensor( - make_gmem_ptr(static_cast,ElementD,ElementC> const*>(nullptr)), + make_gmem_ptr(static_cast(nullptr)), TensorShapeC{}, append<3>(InternalStrideC{}, _0{})), SmemLayoutStageC{}, @@ -248,7 +257,7 @@ class CollectiveEpilogue< using TMA_D = decltype(make_tma_copy( CopyOpS2G{}, make_tensor( - make_gmem_ptr(static_cast(nullptr)), + make_gmem_ptr(static_cast(nullptr)), TensorShapeD{}, append<3>(InternalStrideD{}, _0{})), SmemLayoutStageD{}, @@ -278,6 +287,8 @@ class CollectiveEpilogue< // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. // These will be replaced with correct values before the initial tma load. auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. constexpr int tma_alignment_bits = 128; auto init_M = tma_alignment_bits; auto init_N = tma_alignment_bits; @@ -308,10 +319,13 @@ class CollectiveEpilogue< tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutStageC{}, EpilogueTile{}, _1{}); } - // Tensor pointers will be fixed before the first access - ElementD* ptr_D_first_batch = nullptr; - Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); - typename Params::TMA_D tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutStageD{}, EpilogueTile{}, _1{}); + typename Params::TMA_D tma_store_d{}; + if constexpr (is_destination_supported) { + // Tensor pointers will be fixed before the first access + ElementD* ptr_D_first_batch = nullptr; + Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); + tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutStageD{}, EpilogueTile{}, _1{}); + } auto fusion_workspace = static_cast(workspace); auto fusion_workspace_size = round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment); @@ -359,9 +373,11 @@ class CollectiveEpilogue< auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); auto [M,N,K,L] = problem_shape_MNKL; - constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); - constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + } if constexpr (is_source_supported) { constexpr int tma_alignment_bits_C = 
cutlass::detail::get_input_alignment_bits(); @@ -752,13 +768,9 @@ class CollectiveEpilogue< thread_idx }; - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - // Thread synchronizer for previously issued waits or fences // to ensure visibility of smem reads/writes to threads or TMA unit - auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; // Predication for sub-128 thread T2R tiled copy Layout tmem_warp_layout = typename decltype(make_tmem_warp_partitioner(tAcc_epi(_,_,0,0)))::TiledLayout_TV{}; @@ -795,31 +807,38 @@ class CollectiveEpilogue< [[maybe_unused]] int epi_n_prev = 0; static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); - auto epi_loop_fn = [&] (auto& cst_callbacks) { - // The TMA store sequence for one subtile iteration - auto tma_store_fn = [&] (int epi_m, int epi_n) { + // The Epilogue Loop + auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // The TMA store sequence for one epilogue loop iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { // Write the tile from smem to gmem with TMA cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA synchronize(); // ensure all threads have issued their async fence - if (issue_tma_store) { - copy(params.tma_store_d.with(get<0>(store_tensormap_info)), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); - } + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d.with(get<0>(store_tensormap_info)), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + // Post async fence, pre TMA commit callback entry point cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - + // Commit the TMA stores for this stage if (issue_tma_store) { store_pipeline.producer_commit(store_pipe_producer_state); } ++store_pipe_producer_state; - + // Wait for the next smem buffer to be available if (issue_tma_store) { store_pipeline.producer_acquire(store_pipe_producer_state); } synchronize(); - + if constexpr (ReuseSmemC) { // producer_acquire returns when at most StagesD-1 committed stores are pending bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; @@ -831,11 +850,7 @@ class CollectiveEpilogue< ++load_pipe_consumer_state; } } - }; - - // - // BEGIN EPILOGUE - // + }; // tma_store_fn // Begin the wait for the producer load results ConsumerToken load_wait_token{BarrierStatus::WaitDone}; @@ -850,10 +865,12 @@ class CollectiveEpilogue< synchronize(); } // For each epilogue subtile within the CTA tile - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr 
int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { int epi_m = iter_m, epi_n = iter_n; bool is_first_iteration = iter_m == 0 && iter_n == 0; bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; @@ -953,8 +970,10 @@ class CollectiveEpilogue< // Copy output tile from register to smem bool issue_smem_store = issue_tmem_load; - if (issue_smem_store) { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + if constexpr (is_destination_supported) { + if (issue_smem_store) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } } // Post reduction, pre TMA store callback entry point @@ -982,9 +1001,11 @@ class CollectiveEpilogue< cst_callbacks.end(); }; - epi_loop_fn(cst_callbacks); - cst_callbacks.end(); - + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + epi_loop_fn(cst_callbacks); return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); } @@ -1201,10 +1222,12 @@ class CollectiveEpilogue< } // For each epilogue subtile within the CTA tile - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { int epi_m = iter_m, epi_n = iter_n; bool is_first_iteration = iter_m == 0 && iter_n == 0; bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; @@ -1343,7 +1366,7 @@ class CollectiveEpilogue< } syncwarp(); } - } else { + } else if constexpr (is_destination_supported) { int const offset_Ddesc = cute::is_void_v ? 0 : sm_count; tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc]; if (cute::elect_one_sync()) { @@ -1374,7 +1397,7 @@ class CollectiveEpilogue< params.ptr_C[next_batch]); } } - } else { + } else if constexpr (is_destination_supported) { cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D, params.ptr_D[next_batch]); } @@ -1414,7 +1437,7 @@ class CollectiveEpilogue< } } } - else { + else if constexpr (is_destination_supported) { ElementD const* ptr_D = nullptr; Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); @@ -1464,16 +1487,23 @@ class CollectiveEpilogue< tensormaps_cp_fence_release( TensorMapStorage& shared_tensormap, cute::TmaDescriptor const* tensormap) { + // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem. + // This operation only happens when the group/batch changes between consecutive tiles. + // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. 
+ auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + }; // Entire warp must do this (ie its aligned) if constexpr (IsLoad) { if (is_source_supported) { - if (cute::elect_one_sync()) { - cute::tma_desc_commit_group(); - cute::tma_desc_wait_group(); - } + tma_desc_wait_all_fn(); tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); } - } else { + } else if constexpr (is_destination_supported) { + tma_desc_wait_all_fn(); tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); } } @@ -1486,7 +1516,7 @@ class CollectiveEpilogue< if (is_source_supported) { cute::tma_descriptor_fence_acquire(tensormap); } - } else { + } else if constexpr (is_destination_supported) { cute::tma_descriptor_fence_acquire(tensormap); } } diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp index ba85a75e54..f58f61fcb4 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp @@ -462,6 +462,10 @@ class CollectiveEpilogue< || is_same_v; // alloc reduction buffer for custom EVTs constexpr static size_t ImplicitSharedStorageSize = IsReductionBufferNeeded ? size(EpilogueTile{}) : 0; + // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. + constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + public: constexpr static int ThreadCount = 128; constexpr static uint32_t TmaTransactionBytes = 0; @@ -646,12 +650,12 @@ class CollectiveEpilogue< thread_idx }; - auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); - bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); - auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + // The Epilogue Loop auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); + // Ensure there are no threads from the previous wave writing to shared memory being utilized for the current wave. synchronize(); cst_callbacks.begin(); @@ -669,10 +673,12 @@ class CollectiveEpilogue< static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); // For each epilogue subtile within the CTA tile - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<4>(tTR_tAcc); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<3>(tTR_tAcc); ++iter_m) { + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<4>(tTR_tAcc)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<3>(tTR_tAcc)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? 
NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { int epi_m = iter_m, epi_n = iter_n; bool is_last_iteration = iter_m == size<3>(tTR_tAcc)-1 && iter_n == size<4>(tTR_tAcc)-1; @@ -747,6 +753,10 @@ class CollectiveEpilogue< cst_callbacks.end(); }; + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); epi_loop_fn(cst_callbacks); return cute::make_tuple(acc_pipe_consumer_state); } diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp index 37acf23ae3..6c3f111c11 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp @@ -140,7 +140,6 @@ class CollectiveEpilogue< static_assert(StagesD >= 1, "StagesD must be >= 1"); constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; constexpr static bool is_source_supported = not cute::is_void_v; constexpr static bool is_m_major_C = detail::is_m_major(); @@ -172,6 +171,12 @@ class CollectiveEpilogue< constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. + constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + // TMA store delay only benefits with loop unrolling + constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; + struct CollectiveStorageWithC { alignas(SmemAlignmentC) ArrayEngine> smem_C; alignas(SmemAlignmentD) ArrayEngine> smem_D; @@ -687,7 +692,7 @@ class CollectiveEpilogue< // OOB predication for tile quantization "residue" // Absolute coordinate tensors (dynamic) Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) // Relative coordinate tensors (static) Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) @@ -696,7 +701,7 @@ class CollectiveEpilogue< auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) - // Get the fusion callbacks for the consumer store warps + // Arguments for the fusion callbacks for the consumer store warps constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ problem_shape_mnkl, @@ -713,10 +718,6 @@ class CollectiveEpilogue< thread_idx }; - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - // Thread synchronizer for previously issued waits or fences // to ensure visibility of smem reads/writes to threads or TMA unit auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; 
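Several epilogue hunks above (the SM100 array, TMA warp-specialized, and no-smem epilogues) replace the unconditional `CUTLASS_PRAGMA_UNROLL` over epilogue subtiles with `#pragma unroll(UnrollEpiLoop ? N : 1)`, where `UnrollEpiLoop` is false when the activation functor reports itself as heavy via `kIsHeavy_member_or_false`; the TMA-store delay is likewise disabled in that case, since delaying the store only pays off when the loop is unrolled. The following is a standalone sketch of that pragma idiom under assumed names (`kUnrollEpiLoop` and `kNumSubtiles` are illustrative constants, not CUTLASS symbols).

```cpp
// Standalone .cu sketch of the conditional-unroll idiom. kUnrollEpiLoop and kNumSubtiles
// stand in for UnrollEpiLoop and the CUTE_STATIC_V(...) subtile counts in the hunks above.
constexpr bool kUnrollEpiLoop = false;  // false when the activation op is heavy (kIsHeavy == true)
constexpr int  kNumSubtiles   = 8;      // illustrative compile-time subtile count

__device__ void subtile_loop_sketch(float* frag, float scale) {
  // Fully unroll only when the epilogue op is cheap; otherwise keep the loop rolled
  // (#pragma unroll 1) to limit instruction footprint and register pressure.
  #pragma unroll(kUnrollEpiLoop ? kNumSubtiles : 1)
  for (int i = 0; i < kNumSubtiles; ++i) {
    frag[i] *= scale;  // per-subtile work goes here
  }
}
```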
@@ -756,8 +757,12 @@ class CollectiveEpilogue< [[maybe_unused]] int epi_n_prev = 0; static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + // The Epilogue Loop auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - // The TMA store sequence for one subtile iteration + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // The TMA store sequence for one epilogue loop iteration auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { // Write the tile from smem to gmem with TMA cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA @@ -765,22 +770,22 @@ class CollectiveEpilogue< if (issue_tma_store) { copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); } - + // Post async fence, pre TMA commit callback entry point cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - + // Commit the TMA stores for this stage if (issue_tma_store) { store_pipeline.producer_commit(store_pipe_producer_state); } ++store_pipe_producer_state; - + // Wait for the next smem buffer to be available if (issue_tma_store) { store_pipeline.producer_acquire(store_pipe_producer_state); } synchronize(); - + if constexpr (ReuseSmemC) { // producer_acquire returns when at most StagesD-1 committed stores are pending bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; @@ -792,11 +797,8 @@ class CollectiveEpilogue< ++load_pipe_consumer_state; } } - }; + }; // tma_store_fn - // - // BEGIN EPILOGUE - // cst_callbacks.begin(); if (cst_callbacks.begin_sync_needed()) { synchronize(); @@ -811,10 +813,12 @@ class CollectiveEpilogue< ConsumerToken acc_wait_token = acc_pipeline.consumer_try_wait(acc_pipe_consumer_state); // For each epilogue subtile within the CTA tile - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { int epi_m = iter_m, epi_n = iter_n; bool is_first_iteration = iter_m == 0 && iter_n == 0; bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; @@ -941,9 +945,13 @@ class CollectiveEpilogue< } cst_callbacks.end(); - }; + }; // epi_loop_fn - epi_loop_fn(cst_callbacks); + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + epi_loop_fn(cst_callbacks); return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); } @@ -1161,10 +1169,12 @@ class CollectiveEpilogue< } // For each epilogue subtile within the CTA tile - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? 
NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { int epi_m = iter_m, epi_n = iter_n; bool is_first_iteration = iter_m == 0 && iter_n == 0; bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index b5cdfdcb87..c625f43d2e 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -41,6 +41,7 @@ #include "cutlass/epilogue/thread/scale_type.h" #include "cutlass/epilogue/fusion/callbacks.hpp" #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp" #include "cutlass/detail/collective.hpp" #include "cutlass/detail/layout.hpp" #include "cutlass/trace.h" @@ -304,8 +305,9 @@ class CollectiveEpilogue< // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. // These will be replaced with correct values before the initial tma load. auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); - auto init_M = get<0>(init_shape); - auto init_N = get<1>(init_shape); + constexpr int tma_alignment_bits = 128; + auto init_M = tma_alignment_bits; + auto init_N = tma_alignment_bits; auto init_L = get<3>(init_shape); static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D"); @@ -761,7 +763,14 @@ class CollectiveEpilogue< CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple MMA tiles + CUTE_STATIC_ASSERT(epi_tile_n % mma_tile_n == 0, "MMA_TILE_N must divide EPI_TILE_N"); + } + else { CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + } + // Get TiledCopy for partition reference when consumer store. 
TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); // Get the fusion callbacks for the consumer store warps @@ -784,6 +793,12 @@ class CollectiveEpilogue< bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg(0), 0, 0, 0)); + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rCompute_frg = recast>(tRS_rCompute); + // Thread synchronizer for previously issued waits or fences // to ensure visibility of smem reads/writes to threads or TMA unit auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; @@ -894,17 +909,41 @@ class CollectiveEpilogue< ++load_wait_state; } - int mma_m = epi_m; - int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; - Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); - - // Vectorized fragment loop with visitor callback entry point - int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); - int r2s_v = epi_n_in_mma * size(tRS_rD_frg); - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) { - tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple + // MMA tiles + static constexpr int MmaMPerEpiM = epi_tile_m / mma_tile_m; + static constexpr int MmaNPerEpiN = epi_tile_n / mma_tile_n; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_in_epi = 0; mma_n_in_epi < MmaNPerEpiN; ++mma_n_in_epi) { + int mma_n = (epi_n * MmaNPerEpiN) + mma_n_in_epi; + + CUTLASS_PRAGMA_UNROLL + for (int mma_m_in_epi = 0; mma_m_in_epi < MmaMPerEpiM; ++mma_m_in_epi) { + int mma_m = (epi_m * MmaMPerEpiM) + mma_m_in_epi; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + int idx_in_epi_subtile = (mma_n_in_epi * MmaMPerEpiM + mma_m_in_epi); + + tRS_rCompute_frg(idx_in_epi_subtile) = cst_callbacks.visit( + tRS_rAcc_frg_mn(0), idx_in_epi_subtile, epi_m, epi_n); + } + } + } + else { + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + // Vectorized fragment loop with visitor callback entry point + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rCompute_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rCompute_frg); ++epi_v) { + tRS_rCompute_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + } } + // The latest we can delay the TMA store is right before the smem store of the next iteration // since the current TMA store needs to be committed before we can acquire the next smem buffer if constexpr (DelayTmaStore) { @@ -918,7 +957,7 @@ class CollectiveEpilogue< // Smem reduction callback entry point using current store buffer for workspace cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), - synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); + synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); // Copy tile from register to regiser if needed if constexpr (IsUseR2R) { @@ -930,6 +969,11 @@ class CollectiveEpilogue< copy(tiled_r2r, tRR_rD_src, 
tRR_rD_dst); } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); + } + // Copy tile from register to smem if constexpr (is_destination_supported) { copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); @@ -1140,7 +1184,6 @@ class CollectiveEpilogue< ProblemShape_MNKL problem_shape_mnkl, int32_t next_batch, int32_t warp_group_idx) { - if (cute::elect_one_sync()) { // Replacing global_address for the next batch tensormaps_replace_global_address(shared_tensormaps, params, next_batch, warp_group_idx); @@ -1161,14 +1204,24 @@ class CollectiveEpilogue< TensorMapStorage& shared_tensormaps, cute::TmaDescriptor const* tensormap, const int32_t warp_group_idx = 0) { - + // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem. + // This operation only happens when the group/batch changes between consecutive tiles. + // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. + auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + }; // Entire warp must do this (ie its aligned) if constexpr (IsLoad) { if constexpr (is_source_supported) { + tma_desc_wait_all_fn(); tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_C); } } else if constexpr (is_destination_supported) { + tma_desc_wait_all_fn(); tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_D[warp_group_idx]); } } diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index a2a46b73c9..db53153c50 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -255,6 +255,23 @@ struct Sm120TmaWarpSpecialized { constexpr static bool DelayTmaStore = DelayTmaStore_; }; +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + int NumEpilogueWarpGroups_ +> +struct Sm120PtrArrayTmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; +}; + #if defined (SYCL_INTEL_TARGET) // Specialization of the GEMM Epilogue for Intel Xe architectures. // This version is tuned for operations with a subgroup size of 16. diff --git a/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp index 5c47d70627..28099b2116 100644 --- a/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp @@ -78,12 +78,13 @@ namespace detail { } }(); + // norm_constant and qpvscale_rcps are all positive numbers. + auto acc_scales = cutlass::multiplies>{}(norm_constant, qpvscale_rcps); + CUTLASS_PRAGMA_UNROLL for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { - // norm_constant and qpvscale_rcps[sf_v] are all positive numbers. 
- ElementCompute acc_scale = mul(norm_constant, qpvscale_rcps[sf_v]); // Map INF to fp32::max - acc_scale = minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + auto acc_scale = minimum_with_nan_propagation{}(acc_scales[sf_v], cutlass::platform::numeric_limits::max()); // Convert to output type output_frgs[sf_v] = cutlass::NumericArrayConverter{}(mul_array(compute_frgs[sf_v], acc_scale)); } @@ -240,17 +241,19 @@ struct Sm100BlockScaleFactorRowStore { cutlass::multiplies mul; cutlass::maximum_absolute_value_reduction, true> amax_reduction; + cutlass::Array vec_maxs; cutlass::Array pvscales; // SF generation CUTLASS_PRAGMA_UNROLL for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { compute_frgs[sf_v] = NumericArrayConverter{}(input_frgs[sf_v]); /// Step1: get max across a vector - ElementCompute vec_max = amax_reduction(ElementCompute(0), compute_frgs[sf_v]); - /// Step2: Compute Scale - pvscales[sf_v] = mul(vec_max, norm_constant_scaled_down); + vec_maxs[sf_v] = amax_reduction(ElementCompute(0), compute_frgs[sf_v]); } + /// Step2: Compute Scale + pvscales = cutlass::multiplies>{}(vec_maxs, norm_constant_scaled_down); + tC_rSFD_frg(_0{}) = cutlass::NumericArrayConverter{}(pvscales); Tensor tCgSFD_flt = filter_zeros(tC_gSFD(_,_,_,_0{},_0{},get<0>(epi_tile_coord_mn) + epi_m, get<1>(epi_tile_coord_mn) + epi_n)); diff --git a/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp index 8f391aace0..b769b1f0fb 100644 --- a/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp @@ -1317,6 +1317,277 @@ struct FusionCallbacks< using Impl::Impl; }; +// Sm120 Ptr array tma warp specialized callbacks just alias to their sm90 counterpart +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm120PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... +> : FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... + > { + using FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>::FusionCallbacks; +}; + +// For Ptr-Array and Grouped GEMM +// D = alpha * acc + beta * C, where alpha and beta can be vectors for each batch/group +// With Row BlockScaleFactor Generation, separate tensors per batch/group. 
+template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinearCombRowBlockScaleFactorPtrArray = + Sm90EVT< + Sm120BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor *, RoundStyle + >, // gen scalefactor + Sm90LinearCombinationPtrArray< ElementCompute, ElementCompute, + ElementSource, ElementScalar, RoundStyle + > // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120PtrArrayTmaWarpSpecialized, + fusion::LinCombBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinearCombRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + > { + + using Impl = + Sm120LinearCombRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + >; + + using Operation = + fusion::LinCombBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; + + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + + +// For Ptr-Array and Grouped GEMM +// D = activation(alpha * acc + beta * C), where alpha and beta can be vectors for each batch/group +// With Row BlockScaleFactor Generation, separate tensors per batch/group. +template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinCombEltActRowBlockScaleFactorPtrArray = + Sm90EVT< + Sm120BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor *, RoundStyle + >, // gen scalefactor + Sm90LinCombEltActPtrArray // activation(beta * C + (alpha * acc)) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120PtrArrayTmaWarpSpecialized, + fusion::LinCombEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombEltActRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + > { + + using Impl = + Sm120LinCombEltActRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + >; + + using Operation = + fusion::LinCombEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* 
alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; + + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; } // namespace cutlass::epilogue::fusion ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp index 59a9d03026..e72e971bd8 100644 --- a/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp @@ -94,6 +94,8 @@ struct Sm120BlockScaleFactorRowStore { using Params = Arguments; + using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; + template static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { @@ -390,21 +392,21 @@ struct Sm120BlockScaleFactorRowStore { } ElementCompute pvscale = mul(amax, norm_constant_scaled_down); - ElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); + UnderlyingElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); tC_rSFD_flt(coord) = qpvscale; // // Apply the scale factor to the output // ElementCompute qpvscale_rcp = [&]() { - if constexpr (cute::is_same_v) { + if constexpr (cute::is_same_v) { // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); - return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); + auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); + return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); } else { // UE4M3: Do the rcp in fp32 data type. 
- auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); + auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); } }(); @@ -458,15 +460,24 @@ struct Sm120BlockScaleFactorRowStore { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig; + UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; + // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group + if constexpr (!cute::is_same_v) { + ptr_scale_factor = params_ptr->ptr_scale_factor[l]; + l = 0; + } + else { + ptr_scale_factor = params_ptr->ptr_scale_factor; + } auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(params_ptr->ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); + Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_, _,l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) + Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) auto tile_coord_mn = make_coord(m * size<0>(epi_tile_mn), n * size<1>(epi_tile_mn)); @@ -537,6 +548,8 @@ struct Sm120BlockScaleFactorColStore { }; using Params = Arguments; + using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; + template static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { @@ -770,21 +783,21 @@ struct Sm120BlockScaleFactorColStore { synchronize(); ElementCompute pvscale = mul(amax, norm_constant_scaled_down); - ElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); + UnderlyingElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); filter(tC_rSFD)(sf_id + mma_in_epi*ColsPerThreadAccFrag) = qpvscale; // // Apply the scale factor to the output // ElementCompute qpvscale_rcp = [&]() { - if constexpr (cute::is_same_v) { + if constexpr (cute::is_same_v) { // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); - return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); + auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); + return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); } else { // UE4M3: Do the rcp in fp32 data type. 
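//
// Illustrative sketch (editor's addition, not part of the patch): the hunks above distinguish
// "one scale-factor tensor for all batches" from "one pointer per batch/group" purely through the
// element type. In the grouped case the node's ElementBlockScaleFactor parameter is itself a
// pointer type, so remove_pointer_t strips one level, the batch coordinate l picks that group's
// tensor, and l is then reset to 0 so the tile offset is computed within the selected tensor only.
// Hypothetical standalone analogue using the standard library:
//
#include <cstdio>
#include <type_traits>

template <class Element>   // Element == SF for a single tensor, SF* for per-group pointers
std::remove_pointer_t<Element>* select_scale_factor(Element* params_ptr, int& l) {
  using Underlying = std::remove_pointer_t<Element>;
  if constexpr (!std::is_same_v<Underlying, Element>) {
    Underlying* ptr = params_ptr[l];   // params_ptr is SF**: pick group l's tensor
    l = 0;                             // batch index is consumed by the pointer lookup
    return ptr;
  } else {
    return params_ptr;                 // params_ptr is SF*: one tensor, still indexed by l later
  }
}

int main() {
  float g0[2] = {1.f, 2.f}, g1[2] = {3.f, 4.f};
  float* groups[2] = {g0, g1};

  int l = 1;
  float* per_group = select_scale_factor(groups, l);   // Element deduced as float*
  std::printf("group tensor starts at %g, l -> %d\n", per_group[0], l);

  l = 1;
  float* single = select_scale_factor(g0, l);           // Element deduced as float
  std::printf("single tensor starts at %g, l stays %d\n", single[0], l);
  return 0;
}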
- auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); + auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); } }(); @@ -829,18 +842,27 @@ struct Sm120BlockScaleFactorColStore { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; + UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; + // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group + if constexpr (!cute::is_same_v) { + ptr_scale_factor = params_ptr->ptr_scale_factor[l]; + l = 0; + } + else { + ptr_scale_factor = params_ptr->ptr_scale_factor; + } static_assert(size<0>(EpilogueTile{}) && ((size<0>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(params_ptr->ptr_scale_factor), + Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_, _,l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) + Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) auto tile_coord_mn = make_coord(m * size<0>(epi_tile_mn), n * size<1>(epi_tile_mn)); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index c498a3829f..cd470f84f7 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -1191,9 +1191,11 @@ struct Sm90RowBroadcast { auto layout_M = make_layout(M, repeat_like(M, _0{})); auto layout_L = make_layout(L, get<2>(params.dRow)); - ElementInput const* ptr_row; + ElementInput const* ptr_row = nullptr; if constexpr(IsArrayOfPointers) { - ptr_row = params.ptr_row[l]; + if (!(EnableNullptr && params.ptr_row == nullptr)) { + ptr_row = params.ptr_row[l]; + } } else { ptr_row = params.ptr_row; } @@ -1439,9 +1441,11 @@ struct Sm90ColBroadcast { auto layout_N = make_layout(N, repeat_like(N, _0{})); auto layout_L = make_layout(L, get<2>(params.dCol)); - ElementInput const* ptr_col; + ElementInput const* ptr_col = nullptr; if constexpr(IsArrayOfPointers) { - ptr_col = params.ptr_col[l]; + if (!(EnableNullptr && params.ptr_col == nullptr)) { + ptr_col = params.ptr_col[l]; + } } else { ptr_col = params.ptr_col; } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index ce841bf28b..93720f8d3d 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -116,6 +116,172 @@ sm90_partition_for_epilogue( // ///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Producer load callbacks, called by the epilogue load warp. +// Operations usually only define this if TMA load is needed. 
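//
// Illustrative sketch (editor's addition, not part of the patch): the Sm90RowBroadcast and
// Sm90ColBroadcast hunks above make the per-batch pointer lookup safe when the broadcast input is
// disabled. With an array of per-group pointers, dereferencing params.ptr_row[l] while the whole
// array is nullptr would fault, so the lookup is skipped and the null pointer is kept for the
// EnableNullptr path downstream. Hypothetical standalone analogue:
//
#include <cassert>

template <bool EnableNullptr>
float const* resolve_row_ptr(float const* const* ptr_row_array, int l) {
  float const* ptr_row = nullptr;
  if (!(EnableNullptr && ptr_row_array == nullptr)) {
    ptr_row = ptr_row_array[l];   // only index into the array when the input is actually provided
  }
  return ptr_row;
}

int main() {
  float bias[2] = {0.5f, 1.5f};
  float const* per_group[1] = {bias};

  assert(resolve_row_ptr<true>(per_group, 0) == bias);   // normal grouped lookup
  assert(resolve_row_ptr<true>(nullptr, 0) == nullptr);  // disabled input: no [l] dereference
  return 0;
}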
Most operations will reuse this empy implementation +// Load callbacks are responsible for issuing corresponding mbarrier expect-tx ops for any TMA loads issued, but +// are not responsible for issuing the producer_commit barrier arrival, which is issued by the collective instead +// If this is non-empty, is_producer_load_needed must be true. +// +template +struct ProducerLoadCallbacksImpl { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of the subtile load loop + CUTLASS_DEVICE void + begin() { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.begin(); + } + ); + } + + // Entry of the subtile load loop. Aux loads usually performed here + // Upon entry the producer acquire of the current subtile lock has completed. + // Upon exit all TMA loads for this subtile must have been issued, with corresponding expect-tx operations + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.step(full_mbarrier_ptr, epi_m, epi_n, load_iteration, issue_tma_load); + } + ); + } + + // Exit of the subtile load loop. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.end(); + } + ); + } +}; + + +// +// Consumer store callbacks, called by the epilogue store warps. +// All operations must redefine this, with optional inheritance from this empty implementation. +// +template +struct ConsumerStoreCallbacksImpl { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of subtile store loop. Gmem broadcasts usually performed here. + CUTLASS_DEVICE void + begin() { + for_each(callbacks_tuple, + [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.begin(); + } + ); + } + + // Is a thread sync needed after begin(). Allows chaining async copies across multiple nodes + CUTLASS_DEVICE bool + begin_sync_needed() const { + return cute::apply(callbacks_tuple, + [] (auto const&... callbacks) { + return (false || ... || callbacks.begin_sync_needed()); + } + ); + } + + // Start of subtile store iteration + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.begin_loop(epi_m, epi_n); + } + ); + } + + // Before visit callback. Smem broadcasts usually performed here. + // Upon entry, all producer loads for this subtile are completed and visible. + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.previsit(epi_m, epi_n, load_iteration, is_producer_load_needed); + } + ); + } + + // Perform the fused elementwise computation + template + CUTLASS_DEVICE auto // returns an Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) // depends on the N-naryness of the op + = delete; // Must be implemented for each operation + + // After visit call. 
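//
// Illustrative sketch (editor's addition, not part of the patch): ProducerLoadCallbacksImpl and
// ConsumerStoreCallbacksImpl above simply broadcast each hook to every node's callbacks stored in
// a tuple, and begin_sync_needed() ORs the per-node answers with a fold expression. Hypothetical
// standalone analogue using std::apply in place of the cute helpers:
//
#include <cstdio>
#include <tuple>

struct NodeA {
  void begin() { std::printf("A.begin\n"); }
  bool begin_sync_needed() const { return false; }
};
struct NodeB {
  void begin() { std::printf("B.begin\n"); }
  bool begin_sync_needed() const { return true; }   // e.g. this node issued an async copy
};

template <class CallbacksTuple>
struct CallbacksImpl {
  CallbacksTuple callbacks_tuple;

  void begin() {                                    // broadcast to every node, in tuple order
    std::apply([](auto&... cb) { (cb.begin(), ...); }, callbacks_tuple);
  }
  bool begin_sync_needed() const {                  // true if any node needs a post-begin sync
    return std::apply([](auto const&... cb) { return (false || ... || cb.begin_sync_needed()); },
                      callbacks_tuple);
  }
};

int main() {
  CallbacksImpl<std::tuple<NodeA, NodeB>> impl{ std::make_tuple(NodeA{}, NodeB{}) };
  impl.begin();
  std::printf("sync needed: %d\n", impl.begin_sync_needed());
  return 0;
}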
Smem reductions usually performed here + // reduction_buffer is an arbitrary smem tensor that can be used for workspace + // It is each nodes reponsibility to assert that this buffer is sufficiently sized + // and to ensure that this buffer is no longer needed upon callback exit + // i.e. results are synchronized and no longer in the reduction buffer + // + // visit_results is a rmem tensor that contains the results of visit() for an entire + // on the current epilogue subtile + template + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration, visit_results); + } + ); + } + + // After reduce call, before smem async fence. Smem stores usually performed here. + // Upon exit, all smem stores for TMA must have been issued + CUTLASS_DEVICE void + postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.postreduce(epi_m, epi_n, store_iteration, issue_smem_store); + } + ); + } + + // After smem async fence, before TMA store commit. Aux stores usually performed here + // Upon exit, all TMA stores for this subtile must have been issued + // Because of the TMA store delay optimization, this entry point must ONLY be used for TMA stores + // other gmem stores can be placed in the reduce or postreduce entry points + CUTLASS_DEVICE void + tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.tma_store(epi_m, epi_n, store_iteration, issue_tma_store); + } + ); + } + + // End of subtile store iteration + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.end_loop(epi_m, epi_n); + } + ); + } + + // Exit of subtile store loop. Gmem reductions usually performed here. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.end(); + } + ); + } +}; + template< class ProblemShapeMNKL, class TileShapeMNK, @@ -349,51 +515,6 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { ); } - // - // Producer load callbacks, called by the epilogue load warp. - // Operations usually only define this if TMA load is needed. Most operations will reuse this empy implementation - // Load callbacks are responsible for issuing corresponding mbarrier expect-tx ops for any TMA loads issued, but - // are not responsible for issuing the producer_commit barrier arrival, which is issued by the collective instead - // If this is non-empty, is_producer_load_needed must be true. - // - template - struct ProducerLoadCallbacks { - // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables - CallbacksTuple callbacks_tuple; - - // Before entry of the subtile load loop - CUTLASS_DEVICE void - begin() { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin(); - } - ); - } - - // Entry of the subtile load loop. Aux loads usually performed here - // Upon entry the producer acquire of the current subtile lock has completed. 
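//
// Illustrative sketch (editor's addition, not part of the patch): the call order that the hook
// comments above describe, written out as a plain loop skeleton. The real collective interleaves
// pipeline barriers, the smem async fence, and TMA store commits between these calls, and drives
// visit()/reduce() with partitioned tensors; only the ordering is shown here, with hypothetical
// names.
//
#include <cstdio>

struct PrintCallbacks {   // minimal stand-in that just records the call order
  void begin() { std::puts("begin"); }
  void begin_loop(int, int) { std::puts("  begin_loop"); }
  void previsit(int, int, int, bool) { std::puts("  previsit"); }
  // visit()/reduce() omitted: visit is deleted in the base and defined per operation
  void postreduce(int, int, int, bool) { std::puts("  postreduce"); }
  void tma_store(int, int, int, bool) { std::puts("  tma_store"); }
  void end_loop(int, int) { std::puts("  end_loop"); }
  void end() { std::puts("end"); }
};

template <class Callbacks>
void store_loop_skeleton(Callbacks& cb, int epi_m_tiles, int epi_n_tiles) {
  cb.begin();                                   // gmem broadcasts happen here
  int store_iteration = 0;
  for (int epi_n = 0; epi_n < epi_n_tiles; ++epi_n) {
    for (int epi_m = 0; epi_m < epi_m_tiles; ++epi_m) {
      cb.begin_loop(epi_m, epi_n);
      cb.previsit(epi_m, epi_n, /*load_iteration*/ store_iteration, /*is_producer_load_needed*/ true);
      // ... visit() computes the fused output fragment, reduce() runs smem reductions ...
      cb.postreduce(epi_m, epi_n, store_iteration, /*issue_smem_store*/ true);
      // ... smem async fence is issued by the collective here ...
      cb.tma_store(epi_m, epi_n, store_iteration, /*issue_tma_store*/ true);
      cb.end_loop(epi_m, epi_n);
      ++store_iteration;
    }
  }
  cb.end();                                     // gmem reductions happen here
}

int main() {
  PrintCallbacks cb;
  store_loop_skeleton(cb, 1, 2);
  return 0;
}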
- // Upon exit all TMA loads for this subtile must have been issued, with corresponding expect-tx operations - CUTLASS_DEVICE void - step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.step(full_mbarrier_ptr, epi_m, epi_n, load_iteration, issue_tma_load); - } - ); - } - - // Exit of the subtile load loop. - CUTLASS_DEVICE void - end() { - for_each(callbacks_tuple, - [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end(); - } - ); - } - }; - // Producer load callbacks factory // All operations must redefine this, but most can just dispatch to the base impl template @@ -405,131 +526,11 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { }, [] (auto&&... callbacks) CUTLASS_LAMBDA_FUNC_INLINE { auto callbacks_tuple = cute::make_tuple(callbacks...); - return ProducerLoadCallbacks{callbacks_tuple}; + return ProducerLoadCallbacksImpl{callbacks_tuple}; } ); } - // - // Consumer store callbacks, called by the epilogue store warps. - // All operations must redefine this, with optional inheritance from this empty implementation. - // - template - struct ConsumerStoreCallbacks { - // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables - CallbacksTuple callbacks_tuple; - - // Before entry of subtile store loop. Gmem broadcasts usually performed here. - CUTLASS_DEVICE void - begin() { - for_each(callbacks_tuple, - [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin(); - } - ); - } - - // Is a thread sync needed after begin(). Allows chaining async copies across multiple nodes - CUTLASS_DEVICE bool - begin_sync_needed() const { - return cute::apply(callbacks_tuple, - [] (auto const&... callbacks) { - return (false || ... || callbacks.begin_sync_needed()); - } - ); - } - - // Start of subtile store iteration - CUTLASS_DEVICE void - begin_loop(int epi_m, int epi_n) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin_loop(epi_m, epi_n); - } - ); - } - - // Before visit callback. Smem broadcasts usually performed here. - // Upon entry, all producer loads for this subtile are completed and visible. - CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.previsit(epi_m, epi_n, load_iteration, is_producer_load_needed); - } - ); - } - - // Perform the fused elementwise computation - template - CUTLASS_DEVICE auto // returns an Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const&... frg_inputs) // depends on the N-naryness of the op - = delete; // Must be implemented for each operation - - // After visit call. Smem reductions usually performed here - // reduction_buffer is an arbitrary smem tensor that can be used for workspace - // It is each nodes reponsibility to assert that this buffer is sufficiently sized - // and to ensure that this buffer is no longer needed upon callback exit - // i.e. 
results are synchronized and no longer in the reduction buffer - // - // visit_results is a rmem tensor that contains the results of visit() for an entire - // on the current epilogue subtile - template - CUTLASS_DEVICE void - reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration, visit_results); - } - ); - } - - // After reduce call, before smem async fence. Smem stores usually performed here. - // Upon exit, all smem stores for TMA must have been issued - CUTLASS_DEVICE void - postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.postreduce(epi_m, epi_n, store_iteration, issue_smem_store); - } - ); - } - - // After smem async fence, before TMA store commit. Aux stores usually performed here - // Upon exit, all TMA stores for this subtile must have been issued - // Because of the TMA store delay optimization, this entry point must ONLY be used for TMA stores - // other gmem stores can be placed in the reduce or postreduce entry points - CUTLASS_DEVICE void - tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.tma_store(epi_m, epi_n, store_iteration, issue_tma_store); - } - ); - } - - // End of subtile store iteration - CUTLASS_DEVICE void - end_loop(int epi_m, int epi_n) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end_loop(epi_m, epi_n); - } - ); - } - - // Exit of subtile store loop. Gmem reductions usually performed here. - CUTLASS_DEVICE void - end() { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end(); - } - ); - } - }; - // Consumer store callbacks factory // All operations must redefine this template < @@ -544,7 +545,7 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { }, [] (auto&&... 
callbacks) CUTLASS_LAMBDA_FUNC_INLINE { auto callbacks_tuple = cute::make_tuple(callbacks...); - return ConsumerStoreCallbacks{callbacks_tuple}; + return ConsumerStoreCallbacksImpl{callbacks_tuple}; } ); } @@ -553,8 +554,8 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { ///////////////////////////////////////////////////////////////////////////////////////////////// // Convenience aliases -using EmptyProducerLoadCallbacks = Sm90VisitorImpl<>::ProducerLoadCallbacks>; -using EmptyConsumerStoreCallbacks = Sm90VisitorImpl<>::ConsumerStoreCallbacks>; +using EmptyProducerLoadCallbacks = ProducerLoadCallbacksImpl>; +using EmptyConsumerStoreCallbacks = ConsumerStoreCallbacksImpl>; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -614,9 +615,9 @@ struct Sm90TreeVisitor : Sm90VisitorImpl { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_tuple = Sm90VisitorImpl:: + auto callbacks_impl = Sm90VisitorImpl:: template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + return ConsumerStoreCallbacks(cute::move(callbacks_impl)); } }; @@ -663,9 +664,9 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_tuple = Sm90VisitorImpl:: + auto callbacks_impl = Sm90VisitorImpl:: template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + return ConsumerStoreCallbacks(cute::move(callbacks_impl)); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -739,9 +740,9 @@ struct Sm90TopologicalVisitor : Sm90VisitorImpl { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_tuple = Sm90VisitorImpl:: + auto callbacks_impl = Sm90VisitorImpl:: template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + return ConsumerStoreCallbacks(cute::move(callbacks_impl)); } }; diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 04935e3421..44c606c4ea 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -52,6 +52,18 @@ namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// +// If kIsHeavy is a member, use it. Otherwise, assume that it's false. +template +struct kIsHeavy_member_or_false { + static constexpr bool value = false; +}; +template +struct kIsHeavy_member_or_false::type> { + static constexpr bool value = Op::kIsHeavy; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // Identity operator template struct Identity { @@ -113,6 +125,8 @@ template
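//
// Illustrative sketch (editor's addition, not part of the patch): the activation.h hunk above adds
// a trait that reads Op::kIsHeavy when that member exists and otherwise defaults to false. The
// standard member-detection idiom below (std::void_t) shows the same behavior; the in-tree
// implementation may differ in detail, and the operator types here are made up.
//
#include <cstdio>
#include <type_traits>

template <class Op, class = void>
struct kIsHeavy_member_or_false_sketch {
  static constexpr bool value = false;          // no Op::kIsHeavy member: assume "not heavy"
};

template <class Op>
struct kIsHeavy_member_or_false_sketch<Op, std::void_t<decltype(Op::kIsHeavy)>> {
  static constexpr bool value = Op::kIsHeavy;   // member present: use it
};

struct HeavyOp { static constexpr bool kIsHeavy = true; };
struct LightOp { /* no kIsHeavy member */ };

static_assert(kIsHeavy_member_or_false_sketch<HeavyOp>::value);
static_assert(!kIsHeavy_member_or_false_sketch<LightOp>::value);

int main() {
  std::printf("HeavyOp heavy: %d, LightOp heavy: %d\n",
              kIsHeavy_member_or_false_sketch<HeavyOp>::value,
              kIsHeavy_member_or_false_sketch<LightOp>::value);
  return 0;
}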