
Cutlass 3.9.2 #371


Merged: 41 commits, May 20, 2025

Commits:
6f49218
v3.9 update (#2203)
yzhaiustc Apr 2, 2025
79fc51f
v3.9 update (#2213)
yzhaiustc Apr 3, 2025
df8a550
Update mma_atom.hpp (#2159)
liujshi Apr 3, 2025
09df6ac
[Doc]fix typo (#2174)
liwenju0 Apr 10, 2025
19cc2a5
add support for sm89 in cute and the unit tests (#2177)
kf-zhang Apr 10, 2025
dd76dec
[Doc] Make C++ code more plausible (#2156)
keryell Apr 10, 2025
5120b21
suppress compilation warnings (#2195)
reed-lau Apr 10, 2025
9e1b649
fix-left-inverse-for-nvcc114 (#2196)
reed-lau Apr 10, 2025
b3f3c77
Update tile_iterator.cu (#2204)
Lulullama405 Apr 10, 2025
5e49724
fix: fig link in cute docs (#2216)
Zhang-kg Apr 10, 2025
bb4dd68
Fix broken links and alt text in cluster launch control docs (#2234)
milesvant Apr 21, 2025
ade6376
[SM90] Change register allocation for TileN=208 to avoid spills (#2219)
tridao Apr 21, 2025
81a43e6
Set EpiTile correctly when TileN is not divisible by 32 (#2220)
tridao Apr 21, 2025
8e345c5
fix_missing_stdint (#2199)
wu-kan Apr 24, 2025
331a1f5
cutlass 3.9 update (#2255)
yzhaiustc Apr 24, 2025
f02a7c2
Update README.md for 3.9
hwu36 Apr 24, 2025
be73ad2
Update CHANGELOG.md for 3.9
hwu36 Apr 24, 2025
e94e888
Update CHANGELOG.md
hwu36 Apr 25, 2025
6971260
fix blackwell grouped groupwise hang (#2267)
hwu36 Apr 29, 2025
2b78c2f
cherry-pick feature/hopper-blockwise-generalization-optimization (#2270)
IwakuraRein Apr 29, 2025
e5b810b
Use cudaMemcpyAsync in gemm grouped with kRequiresPrecomputation sche…
HydraQYH Apr 30, 2025
35136f5
Fix wrong detection of python version for `use_rmm`. (#2224)
eliphatfs Apr 30, 2025
fe75ead
Import pydot lazily (#2248)
mlazos Apr 30, 2025
b3ce7e1
Make cc a positional argument (#2249)
mlazos Apr 30, 2025
c4bdfe8
Lazy scipy import (#2250)
mlazos Apr 30, 2025
e3cb8a7
Import cuda, cudart, nvrtc lazily (#2251)
mlazos May 1, 2025
f535c33
3.9.1 doc/version change (#2273)
hwu36 May 1, 2025
89f6bf2
Fix group scale gemm when K==128 (#2275)
x86vk May 2, 2025
40f124e
[CUTLASS] Add GNA to PUBLICATIONS.md (#2276)
alihassanijr May 2, 2025
ad7b2f5
3.9.2 doc/version (#2279)
hwu36 May 4, 2025
3dfe47b
Merge remote-tracking branch 'codeplay/main' into aacosta/3.9.2
aacostadiaz May 13, 2025
443c793
Use gpu_generics functions
aacostadiaz May 14, 2025
aedee57
Add Xe Group Scheduler
muhammad-tanvir-1211 May 14, 2025
93b3324
Merge pull request #1 from muhammad-tanvir-1211/xe_group_scheduler
aacostadiaz May 14, 2025
692e584
Fix tests
aacostadiaz May 20, 2025
33b332f
Merge remote-tracking branch 'origin/aacosta/3.9.2' into aacosta/3.9.2
aacostadiaz May 20, 2025
b53d801
Merge branch 'sycl-develop' into aacosta/3.9.2
aacostadiaz May 20, 2025
c2efea2
fix python
aacostadiaz May 20, 2025
3e47ace
Merge remote-tracking branch 'origin/aacosta/3.9.2' into aacosta/3.9.2
aacostadiaz May 20, 2025
a121708
fix python
aacostadiaz May 20, 2025
f909ce4
fix python
aacostadiaz May 20, 2025
87 changes: 58 additions & 29 deletions CHANGELOG.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -765,6 +765,7 @@ target_include_directories(
CUTLASS
SYSTEM INTERFACE
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
+$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
)

install(
2 changes: 2 additions & 0 deletions PUBLICATIONS.md
@@ -6,6 +6,8 @@

- ["ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization"](https://arxiv.org/abs/2502.02631). Zechun Liu, Changsheng Zhao, Hanxian Huang, Sijia Chen, Jing Zhang, Jiawei Zhao, Scott Roy, Lisa Jin, Yunyang Xiong, Yangyang Shi, Lin Xiao, Yuandong Tian, Bilge Soran, Raghuraman Krishnamoorthi, Tijmen Blankevoort, Vikas Chandra. _arXiv_, February 2025.

- ["Generalized Neighborhood Attention: Multi-dimensional Sparse Attention at the Speed of Light"](https://arxiv.org/abs/2504.16922). Ali Hassani, Fengzhe Zhou, Aditya Kane, Jiannan Huang, Chieh-Yun Chen, Min Shi, Steven Walton, Markus Hoehnerbach, Vijay Thakkar, Michael Isaev, Qinsheng Zhang, Bing Xu, Haicheng Wu, Wen-mei Hwu, Ming-Yu Liu, Humphrey Shi. _arXiv_, April 2025.

## 2024

- ["DeepSeek-V3 Technical Report"](https://arxiv.org/abs/2412.19437). DeepSeek-AI. _arXiv_, December 2024.
80 changes: 49 additions & 31 deletions README.md
@@ -1,8 +1,8 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

-# CUTLASS 3.9.0
+# CUTLASS 3.9.2

-_CUTLASS 3.9.0 - March 2025_
+_CUTLASS 3.9.2 - May 2025_

**This repository fast-follows the NVIDIA CUTLASS repository, adding SYCL support for Intel GPUs.**
The CUDA support is unmodified from upstream and can be used interchangeably.
@@ -39,9 +39,9 @@ the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution
operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline.
This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
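
As a brief illustration of the formulation described above (an editorial sketch, not part of this diff): a forward NHWC convolution can be computed as a GEMM whose A operand is gathered from the activation tensor on the fly instead of being materialized by im2col. All tensor names, layouts, and the no-padding/unit-stride choices below are illustrative.

```cpp
#include <vector>

// Naive reference for the implicit-GEMM view of a forward convolution:
// GEMM M = N*P*Q (output pixels), GEMM N = K (filters), GEMM K = R*S*C.
void conv2d_implicit_gemm(
    const std::vector<float>& x,  // activations, NHWC
    const std::vector<float>& w,  // filters, KRSC
    std::vector<float>& y,        // output, NPQK
    int N, int H, int W, int C, int K, int R, int S) {
  int P = H - R + 1, Q = W - S + 1;  // no padding, unit stride
  for (int gemm_m = 0; gemm_m < N * P * Q; ++gemm_m) {
    int n = gemm_m / (P * Q), p = (gemm_m / Q) % P, q = gemm_m % Q;
    for (int gemm_n = 0; gemm_n < K; ++gemm_n) {
      float acc = 0.0f;
      for (int gemm_k = 0; gemm_k < R * S * C; ++gemm_k) {
        int r = gemm_k / (S * C), s = (gemm_k / C) % S, c = gemm_k % C;
        // "Implicit" A matrix: A[gemm_m][gemm_k] is gathered from x on the fly.
        acc += x[((n * H + (p + r)) * W + (q + s)) * C + c] *
               w[((gemm_n * R + r) * S + s) * C + c];
      }
      y[gemm_m * K + gemm_n] = acc;  // C matrix, row-major M x N
    }
  }
}
```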

-See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly.
+See the [Quick Start Guide](./media/docs/cpp/quickstart.md) to get started quickly.

-See the [functionality docs](./media/docs/functionality.md) for a more comprehensive
+See the [functionality docs](./media/docs/cpp/functionality.md) for a more comprehensive
list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU
architecture.

@@ -57,18 +57,35 @@ architecture.
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
- [Grouped GEMM with nvfp4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu).
- [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
* Support for Blackwell SM100 Sparse kernels:
- Collective mainloop targeting
* [SM100 Sparse GEMM](./include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp)
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM:
- [Sparse GEMM](./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu)
- [Blockscaled Sparse GEMM with NVFP4 input data type](./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
- [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
* Set of unit tests that demonstrate the usage of [sparse](./test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](./test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM.
* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/) covers the flashMLA-like weight-absorbed decoding use-case.
* A new FMHA Backward kernel for SM100 Blackwell architecture extends CUTLASS [example](./examples/77_blackwell_fmha/) to show how the five backward pass MMAs can be fused into a single kernel to achieve high performance.
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
* Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures:
- Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
- Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
-- Support for [grouped GEMM with blockwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
+- Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
- Support for [groupwise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in the CUTLASS profiler.
- Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
- Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
-* Added support for enhanced kernel performance search in CUTLASS:
+- Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
+* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler:
- Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels (see the sketch after this list).
- Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
- Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
-- More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
+- More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
* Support `void` as the D element in sm100 kernel epilogues.
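
The profiler features above rank kernels by GFLOPs/second. As a point of reference, that figure is derived from the GEMM problem size and the measured runtime in the conventional way; a minimal sketch (editorial, not part of this diff):

```cpp
#include <cstdint>
#include <cstdio>

// A dense M x N x K GEMM performs 2*M*N*K floating-point operations
// (each multiply-accumulate counts as two).
double gemm_gflops(std::int64_t m, std::int64_t n, std::int64_t k,
                   double runtime_ms) {
  double flop = 2.0 * double(m) * double(n) * double(k);
  return flop / (runtime_ms * 1.0e6);  // ms -> s and FLOP -> GFLOP
}

int main() {
  // e.g. a 4096x4096x4096 GEMM measured at 1.25 ms
  std::printf("%.1f GFLOP/s\n", gemm_gflops(4096, 4096, 4096, 1.25));
  return 0;
}
```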

Note: CUTLASS 3.x builds are known to be broken on Windows platforms for all CUDA toolkits.
The CUTLASS team is working on a fix.
@@ -115,7 +132,7 @@ Layouts can also be combined and manipulated via functional composition, on whic
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
This greatly simplifies the design and improves code composability and readability.
More documentation specific to CuTe can be found in its
-[dedicated documentation directory](./media/docs/cute/00_quickstart.md).
+[dedicated documentation directory](./media/docs/cpp/cute/00_quickstart.md).
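
The functional composition mentioned above is exposed in CuTe as `cute::composition`. A minimal host-side sketch (editorial, not part of this diff), assuming only a CUTLASS checkout on the include path:

```cpp
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  // An 8x4 layout with strides (4,1): coordinate (i,j) maps to offset 4*i + j.
  auto a = make_layout(make_shape(Int<8>{}, Int<4>{}),
                       make_stride(Int<4>{}, Int<1>{}));

  // A 4x8 layout selecting indices into a's domain.
  auto b = make_layout(make_shape(Int<4>{}, Int<8>{}),
                       make_stride(Int<8>{}, Int<1>{}));

  // R = A o B, i.e. R(c) = A(B(c)) -- layouts composed as functions.
  auto r = composition(a, b);

  print(r);  // expected: (_4,_8):(_1,_4)
  print("\n");
  return 0;
}
```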

# Compatibility

@@ -162,6 +179,7 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
+|NVIDIA GeForce RTX 50x0 series |10.0|12.8|

## Target Architecture

@@ -197,30 +215,30 @@ NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
compiled for Blackwell SM100 architecture with arch conditional features
(using `sm100a`) are not compatible with RTX 50 series GPUs.

-Please refer to the [functionality documentation](./media/docs/functionality.md)
+Please refer to the [functionality documentation](./media/docs/cpp/functionality.md)
for details on which kernels require which target architectures.

# Documentation

CUTLASS is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).

-- [Quick Start Guide](./media/docs/quickstart.md) - basics of building and running CUTLASS
-- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS
-- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
-- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
-- [GEMM API 3.x](./media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
-- [GEMM API 2.x](./media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
-- [Implicit GEMM Convolution](./media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
-- [Code Organization](./media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
-- [Terminology](./media/docs/terminology.md) - describes terms used in the code
-- [Programming Guidelines](./media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
-- [Fundamental types](./media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
-- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory
-- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
-- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application
-- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilitate rapid development
-- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
+- [Quick Start Guide](./media/docs/cpp/quickstart.md) - basics of building and running CUTLASS
+- [Functionality](./media/docs/cpp/functionality.md) - summarizes functionality available in CUTLASS
+- [Efficient GEMM in CUDA](./media/docs/cpp/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
+- [CUTLASS 3.x Design](./media/docs/cpp/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
+- [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
+- [GEMM API 2.x](./media/docs/cpp/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
+- [Implicit GEMM Convolution](./media/docs/cpp/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
+- [Code Organization](./media/docs/cpp/code_organization.md) - describes the organization and contents of the CUTLASS project
+- [Terminology](./media/docs/cpp/terminology.md) - describes terms used in the code
+- [Programming Guidelines](./media/docs/cpp/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
+- [Fundamental types](./media/docs/cpp/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
+- [Layouts](./media/docs/cpp/layout.md) - describes layouts of matrices and tensors in memory
+- [Tile Iterators](./media/docs/cpp/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
+- [CUTLASS Profiler](./media/docs/cpp/profiler.md) - command-line driven profiling application
+- [CUTLASS Utilities](./media/docs/cpp/utilities.md) - additional templates used to facilitate rapid development
+- [Dependent kernel launch](./media/docs/cpp/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
kernels in the same stream, and how it is used in CUTLASS.

# Resources
@@ -240,7 +258,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
paths.

CUTLASS unit tests, examples, and utilities can be built with CMake.
-The minimum version of CMake is given in the [Quickstart guide](./media/docs/quickstart.md).
+The minimum version of CMake is given in the [Quickstart guide](./media/docs/cpp/quickstart.md).
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
on your system.

@@ -285,7 +303,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
and template concepts defined in the CUTLASS project.

A detailed explanation of the source code organization may be found in the
-[CUTLASS documentation](./media/docs/code_organization.md), but several main components are summarized below.
+[CUTLASS documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below.

## CUTLASS Template Library

@@ -359,7 +377,7 @@ tools/
The `test/unit/` directory consists of unit tests implemented with Google Test that demonstrate
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.

-Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/quickstart.md).
+Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/cpp/quickstart.md).

# Performance Profiling

@@ -575,9 +593,9 @@ reference_device: Passed

## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
-- [GEMM CMake Examples](./media/docs/quickstart.md#gemm-cmake-examples)
-- [Implicit GEMM convolution CMake Examples](./media/docs/quickstart.md#convolution-cmake-examples)
-- [Further details about the CUTLASS Profiler are described here.](./media/docs/profiler.md)
+- [GEMM CMake Examples](./media/docs/cpp/quickstart.md#gemm-cmake-examples)
+- [Implicit GEMM convolution CMake Examples](./media/docs/cpp/quickstart.md#convolution-cmake-examples)
+- [Further details about the CUTLASS Profiler are described here.](./media/docs/cpp/profiler.md)


# About
6 changes: 3 additions & 3 deletions examples/04_tile_iterator/tile_iterator.cu
@@ -34,7 +34,7 @@
addressable memory, and then store it back into addressable memory.

TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data to
-and from addressable memory. The PredicateTileIterator accepts a ThreadMap type, which defines
+and from addressable memory. The PredicatedTileIterator accepts a ThreadMap type, which defines
the mapping of threads to a "tile" in memory. This separation of concerns enables user-defined
thread mappings to be specified.

@@ -124,7 +124,7 @@ __global__ void copy(

cudaError_t TestTileIterator(int M, int K) {

-// For this example, we chose a <64, 4> tile shape. The PredicateTileIterator expects
+// For this example, we chose a <64, 4> tile shape. The PredicatedTileIterator expects
// PitchLinearShape and PitchLinear layout.
using Shape = cutlass::layout::PitchLinearShape<64, 4>;
using Layout = cutlass::layout::PitchLinear;
@@ -136,7 +136,7 @@ cudaError_t TestTileIterator(int M, int K) {
// dimension then along the strided dimension.
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<Shape, kThreads>;

-// Define the PredicateTileIterator, using TileShape, Element, Layout, and ThreadMap types
+// Define the PredicatedTileIterator, using TileShape, Element, Layout, and ThreadMap types
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
Shape, Element, Layout, 1, ThreadMap>;

@@ -402,7 +402,7 @@ struct Options : MixedDtypeOptions{
void initialize(Options const& options) {

auto shape_B = cute::make_shape(options.n, options.k, options.l);
-int const scale_k = (options.k + options.g - 1) / options.g;
+int const scale_k = cutlass::ceil_div(options.k, options.g);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
// Reverse stride here due to swap and transpose
@@ -429,7 +429,7 @@ void initialize(Options const& options) {
block_zero.reset(scale_k * options.l * options.n);

initialize_tensor(block_A, seed + 2022);
-initialize_quant_tensor(block_B, seed + 2021);
+initialize_tensor(block_B, seed + 2021);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);
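
A side note on the recurring change above: `(k + g - 1) / g` and `cutlass::ceil_div(k, g)` compute the same round-up division for positive operands; the named helper only makes the intent explicit. A minimal sketch (editorial, not part of this diff), assuming `cutlass::ceil_div` is provided by `cutlass/fast_math.h` as in recent CUTLASS releases:

```cpp
#include <cassert>
#include "cutlass/fast_math.h"  // assumed home of cutlass::ceil_div

int main() {
  int k = 4096, g = 128;
  int manual = (k + g - 1) / g;          // hand-written round-up division
  int named  = cutlass::ceil_div(k, g);  // same value, clearer intent
  assert(manual == named && named == 32);

  // Non-multiple case: ceil(4097 / 128) = 33 scale groups.
  assert(cutlass::ceil_div(4097, 128) == 33);
  return 0;
}
```
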
@@ -318,7 +318,7 @@ struct Options : MixedDtypeOptions {
void initialize(Options const& options) {

auto shape_B = cute::make_shape(options.n, options.k, options.l);
-int const scale_k = (options.k + options.g - 1) / options.g;
+int const scale_k = cutlass::ceil_div(options.k, options.g);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
// Reverse stride here due to swap and transpose
@@ -347,7 +347,7 @@ void initialize(Options const& options) {
block_zero.reset(scale_k * options.l * options.n);

initialize_tensor(block_A, seed + 2022);
-initialize_quant_tensor(block_B, seed + 2021);
+initialize_tensor(block_B, seed + 2021);
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
@@ -288,7 +288,7 @@ cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::Ele
void initialize(MixedDtypeOptions const& options) {

auto shape_b = cute::make_shape(options.n, options.k, options.l);
-int const scale_k = (options.k + options.g - 1) / options.g;
+int const scale_k = cutlass::ceil_div(options.k, options.g);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
// Reverse stride here due to swap and transpose
@@ -313,7 +313,7 @@ void initialize(MixedDtypeOptions const& options) {
block_zero.reset(scale_k * options.l * options.n);

initialize_tensor(block_A, seed + 2022);
-initialize_quant_tensor(block_B, seed + 2021);
+initialize_tensor(block_B, seed + 2021);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);