Commit 15a5738

jwfromm authored and facebook-github-bot committed
Enable preshuffled mixed dtype Cutlass Gemm (pytorch#3722)
Summary: Enable new preshuffled FP8 x I4 kernels. These are the most performant mixed-dtype kernels to date and dramatically outperform prior approaches, including those in FBGEMM, Marlin, and Machete.

Differential Revision: D69955197
1 parent 504b98b · commit 15a5738

File tree

4 files changed: +452 −0 lines

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 48 additions & 0 deletions
@@ -1276,6 +1276,54 @@ def cuda(self) -> bool:
        return True


@register_quantize_op
class F8I4ShuffledGemm(F8I4RowwiseGemm):
    def _int4_row_quantize(
        self,
        x: torch.Tensor,
        group_size: int = 128,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        n_bit = 4  # Number of target bits.
        to_quant = x.reshape(-1, group_size).to(torch.float)

        max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
        max_int = 2 ** (n_bit - 1)
        min_int = -(2 ** (n_bit - 1))
        scales = max_val.clamp(min=1e-6) / max_int

        out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)

        # Cast to int8 and restore shape.
        out = out.to(dtype=torch.int8).reshape(x.shape)

        # Scales should be in [num_groups, N] layout.
        scales = scales.view(x.shape[0], -1).t().contiguous()

        return out, scales

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale = self._int4_row_quantize(w)
        # Pack int4 values together.
        wq = self._pack_int4(wq)
        # Shuffle weights and scales for faster compute.
        wq, w_scale = torch.ops.fbgemm.preshuffle_i4(wq, w_scale)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        out = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq, x_scale, w_scale)
        return out

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cutlass_f8i4_preshuffle"


@register_quantize_op
class BF16I4RowwiseGemm(F8I4RowwiseGemm):
    """
Lines changed: 313 additions & 0 deletions
@@ -0,0 +1,313 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass/util/packed_stride.hpp"

#include "cutlass_extensions/include/kernel_mode.h"

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

template <int TB_M, int TB_N, int TBS_M, int TBS_N, int TBS_K, bool COOP>
at::Tensor _f8i4bf16_shuffled(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale) {
  // Get shape information from input tensors.
  int M = XQ.size(0);
  int K = XQ.size(1);
  int N = WQ.size(0);
  // Make sure w_scale is in proper format.
  TORCH_CHECK(
      w_scale.size(1) == 8,
      "Weights and scales must be prepacked with preshuffle_i4.");
  int num_groups = w_scale.size(0);
  int group_size = K / num_groups;
  // Allocate output.
  at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

  // Define input types.
  using MmaType = cutlass::float_e4m3_t;
  using QuantType = cutlass::int4b_t;
  constexpr int TileShapeK = 128 * 8 / cute::sizeof_bits<MmaType>::value;

  // A Matrix configuration.
  using ElementA = MmaType;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B Matrix configuration.
  using ElementB = QuantType;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // We need to manually swap and transpose inputs. Unclear how required this
  // is though.
  using LayoutA_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  using LayoutB_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutB>::type;

  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
  using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;

  // Define layout for shuffled weight tensor.
  using LayoutAtomQuant =
      decltype(cutlass::compute_memory_reordering_atom<MmaType>());
  using LayoutB_Reordered = decltype(cute::tile_to_shape(
      LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));

  using ElementScale = MmaType;

  // Output Matrix configuration.
  using ElementC = cutlass::bfloat16_t;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  // Core kernel configurations.
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  // TODO: tune these shapes.
  using TileShape =
      cute::Shape<cute::Int<TB_M>, cute::Int<TB_N>, cute::Int<TileShapeK>>;
  using ClusterShape =
      cute::Shape<cute::Int<TBS_M>, cute::Int<TBS_N>, cute::Int<TBS_K>>;
  // TODO: Should we use fast accum here?
  using KernelSchedule = cute::conditional_t<
      COOP,
      cutlass::gemm::KernelTmaWarpSpecializedCooperative,
      cutlass::gemm::KernelTmaWarpSpecialized>;
  // Might be the only epilogue schedule that supports swap + transpose.
  using EpilogueSchedule = cute::conditional_t<
      COOP,
      cutlass::epilogue::TmaWarpSpecializedCooperative,
      cutlass::epilogue::TmaWarpSpecialized>;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  // Define EVT for rowwise scaling.
  using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementAccumulator,
      ElementAccumulator,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementC, // First stage output type.
      ElementAccumulator, // First stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EpilogueEVT =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, XScale, Accum>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90,
          cutlass::arch::OpClassTensorOp,
          TileShape,
          ClusterShape,
          EpilogueTileType,
          ElementAccumulator,
          ElementAccumulator,
          ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC,
          ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC,
          EpilogueSchedule,
          EpilogueEVT>::CollectiveOp;

  using CollectiveMainloopShuffled =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
          LayoutB_Reordered,
          AlignmentB,
          ElementA,
          LayoutA_Transpose,
          AlignmentA,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>,
      CollectiveMainloopShuffled,
      CollectiveEpilogue>;

  using GemmShuffled =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

  using StrideC = typename GemmKernelShuffled::StrideC;

  /// Initialization
  auto shape_B = cute::make_shape(N, K, 1);
  StrideA stride_A =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
  StrideC stride_C =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
  using StrideS = typename CollectiveMainloopShuffled::StrideScale;
  StrideS stride_S = cutlass::make_cute_packed_stride(
      StrideS{}, cute::make_shape(N, num_groups, 1));

  // Define Gemm arguments.
  typename GemmShuffled::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {N, M, K, 1},
      {reinterpret_cast<ElementB*>(WQ.data_ptr()),
       layout_B_reordered,
       reinterpret_cast<ElementA*>(XQ.data_ptr()),
       stride_A,
       reinterpret_cast<cutlass::Array<ElementScale, 8>*>(w_scale.data_ptr()),
       stride_S,
       group_size},
      {{},
       reinterpret_cast<ElementC*>(Y.data_ptr()),
       stride_C,
       reinterpret_cast<ElementC*>(Y.data_ptr()),
       stride_C}};

  arguments.epilogue.thread = {
      {reinterpret_cast<ElementAccumulator*>(x_scale.data_ptr())}, // x_scale
      {}, // Accumulator
      {}, // Multiplies
  };

  // Launch the workload.
  GemmShuffled gemm;

  // Using the arguments, query for extra workspace required for the matrix
  // multiplication computation.
  size_t workspace_size = GemmShuffled::get_workspace_size(arguments);

  // Allocate workspace memory.
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check whether the problem size is supported.
  cutlass::Status status = gemm.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize the CUTLASS kernel with arguments and workspace pointer.
  status = gemm.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  status = gemm(at::cuda::getCurrentCUDAStream());

  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot run") +
        cutlass::cutlassGetStatusString(status));
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return Y;
}

at::Tensor f8i4bf16_shuffled(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale) {
  int M = XQ.size(0);
  int K = XQ.size(1);
  int N = WQ.size(0);
  // Use shape heuristics to dispatch to an optimized kernel configuration.
  if (M <= 16) {
    return _f8i4bf16_shuffled<64, 16, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
  } else if (M <= 32) {
    return _f8i4bf16_shuffled<64, 32, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
  } else if (M <= 64) {
    return _f8i4bf16_shuffled<64, 64, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
  } else if (M <= 128) {
    return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
        XQ, WQ, x_scale, w_scale);
  } else if (M <= 256) {
    if (N <= 4096) {
      return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    } else {
      return _f8i4bf16_shuffled<64, 256, 1, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    }
  } else if (M <= 512) {
    if (N <= 4096) {
      return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    } else {
      return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
          XQ, WQ, x_scale, w_scale);
    }
  } else if (M <= 1024) {
    if (N <= 1024) {
      return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    } else if (N <= 2048) {
      return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    } else {
      return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
          XQ, WQ, x_scale, w_scale);
    }
  } else {
    if (N <= 1024) {
      return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    } else {
      return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
          XQ, WQ, x_scale, w_scale);
    }
  }
}

#else

at::Tensor f8i4bf16_shuffled(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale) {
  throw std::runtime_error(
      "CUDA version is older than 12.0"); // requires CUDA>=12
}

#endif

} // namespace fbgemm_gpu
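
As a sanity check on the numerics, the following is a rough pure-PyTorch reference for what the fused kernel computes, offered only as a sketch: it ignores the preshuffled memory layout and the FP8 storage of the weight scales, and the names f8i4_reference and wq_int8 are illustrative, not part of the commit. It dequantizes the int4 weights with their groupwise scales, accumulates the GEMM in fp32 (matching ElementAccumulator), and applies the rowwise activation scales in the epilogue, mirroring the Sm90 EVT above.

    # Hedged reference sketch for Y = (XQ @ Wdeq^T) * x_scale.
    import torch

    def f8i4_reference(xq, x_scale, wq_int8, w_scale, group_size=128):
        # xq: [M, K] FP8 activations; x_scale: [M] rowwise scales from quantize_fp8_row.
        # wq_int8: [N, K] int4 values stored in int8 (before packing);
        # w_scale: [num_groups, N] groupwise weight scales (before preshuffling).
        n, k = wq_int8.shape
        # Expand groupwise scales from [num_groups, N] back to [N, K].
        w_scale_full = w_scale.t().float().repeat_interleave(group_size, dim=1)
        w_dequant = wq_int8.float() * w_scale_full
        acc = xq.float() @ w_dequant.t()  # fp32 accumulation over K.
        # Rowwise activation scaling, fused into the kernel's epilogue.
        return (acc * x_scale.float().view(-1, 1)).to(torch.bfloat16)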
