Skip to content

Commit b9acfeb

Browse files
Performance Optimization: Optimized TileShape Configuration for f8 (#3735)
Summary: Pull Request resolved: #3735 X-link: facebookresearch/FBGEMM#816 ## Performance Issue with Current F8 TileShape Configuration The current FBGEMM f8 kernel uses a TileShape configuration of 128x128x128, while the optimal shape for dense f8 tensor core on H100 is m64n256k32. The current configuration leads to suboptimal performance for tensor cores and bandwidth usage. ## Optimized TileShape (128x256x128) Implementation Modification of the TileShape configuration from 128x128x128 to 128x256x128 for large GEMM operations using a cooperative kernel, enabling optimal bandwidth and tensor core utilization. This configuration is notably used in Flash Attention V3 for f8. ## Benchmark Results on H100 GPU ### Benchmark configuration: PyTorch 2.6 CUDA 12.4 CPU: AMD EPYC GPU: NVIDIA H100 Benchmarks are configured with 30 kernel launch iterations and averaged over 25 benchmark calculations. We used the same GEMM sizes as in the Colfax benchmarks. ### Benchmark #### f8f8bf16_grouped (G = 4, M = 2,048, N = 8,192, K = 8,192) | TileShape | TFlops | |-------------|-------- | | 128-128-128 | 1244 | | 128-256-128 | 1374 | #### f8f8bf16_rowwise (M = N = K = 8,192) | TileShape | TFlops | |-------------|------- | | 128-128-128 | 1300 | | 128-256-128 | 1480 | #### f8f8bf16_tensorwise (M=N=K = 8,192) | TileShape | TFlops | |-------------|------- | | 128-128-128 | 1271 | | 128-256-128 | 1463 | ## Technical Implementation Modified TileShape from 128-128-128 to 128-256-128 for: - f8f8bf16_grouped - f8f8bf16_rowwise - f8f8bf16_tensorwise Added cooperative kernel by default for: - f8f8bf16_rowwise - f8f8bf16_tensorwise f8f8f16.cu was not modified because it is deprecated in favor of f8f8bf16_tensorwise. The modifications only affect large GEMMs where M > 128 and N > 128 and M or N > 2,048. The matrices are divided into tiles twice as large, but with kernels using 3 SMs instead of 2. 
The smaller GEMM sizes handled by the large-kernel heuristic may experience slightly reduced efficiency compared to the previous configuration. An empirical study of F8 kernel configurations across GEMM sizes could benefit FBGEMM. These changes were made by modifying the minimum necessary code while respecting existing coding practices in FBGEMM. ## Test Coverage ### Unit Tests Results The unit tests in fbgemm_gpu/experimental/gen_ai/test/quantize have been verified for the modified kernels. jiawenliu64 jwfromm Thank you! Pull Request resolved: #3617 Reviewed By: sunfish2010 Differential Revision: D68719476 Pulled By: jiawenliu64 fbshipit-source-id: 60705574aa1779e0171fea01addf8f20788c4749
1 parent f227f75 commit b9acfeb

File tree

2 files changed

+38
-23
lines changed

2 files changed

+38
-23
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,17 @@ __global__ void set_kernel_args_kernel(
185185
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
186186
problem_shape_buf);
187187
// Pass dummy configs to get Stride structure
188-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
188+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
189189
StrideInputA* stride_input_A_ptr = reinterpret_cast<
190-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
190+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
191191
StrideInputA*>(stride_buf);
192-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
192+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
193193
StrideInputB* stride_input_B_ptr = reinterpret_cast<
194-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
194+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
195195
StrideInputB*>(stride_buf + stride_size);
196-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
196+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
197197
StrideOutput* stride_output_ptr = reinterpret_cast<
198-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
198+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
199199
StrideOutput*>(stride_buf + (stride_size * 2));
200200

201201
output_args_ptr[group_index] =
@@ -210,15 +210,15 @@ __global__ void set_kernel_args_kernel(
210210
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape(M, N, K);
211211
stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
212212
typename GroupedGemmArgs::
213-
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
213+
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideInputA{},
214214
{M, K, 1});
215215
stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
216216
typename GroupedGemmArgs::
217-
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
217+
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideInputB{},
218218
{N, K, 1});
219219
stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
220220
typename GroupedGemmArgs::
221-
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
221+
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideOutput{},
222222
{M, N, 1});
223223
}
224224
}
@@ -263,17 +263,17 @@ __global__ void set_dynamic_kernel_args_kernel(
263263
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
264264
problem_shape_buf);
265265
// Pass dummy configs to get Stride structure
266-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
266+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
267267
StrideInputA* stride_input_A_ptr = reinterpret_cast<
268-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
268+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
269269
StrideInputA*>(stride_buf);
270-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
270+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
271271
StrideInputB* stride_input_B_ptr = reinterpret_cast<
272-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
272+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
273273
StrideInputB*>(stride_buf + stride_size);
274-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
274+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
275275
StrideOutput* stride_output_ptr = reinterpret_cast<
276-
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
276+
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
277277
StrideOutput*>(stride_buf + (stride_size * 2));
278278

279279
output_args_ptr[group_index] =
@@ -289,15 +289,15 @@ __global__ void set_dynamic_kernel_args_kernel(
289289
zero_start_index_M[group_index], N, K);
290290
stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
291291
typename GroupedGemmArgs::
292-
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
292+
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideInputA{},
293293
{zero_start_index_M[group_index], K, 1});
294294
stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
295295
typename GroupedGemmArgs::
296-
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
296+
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideInputB{},
297297
{N, K, 1});
298298
stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
299299
typename GroupedGemmArgs::
300-
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
300+
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideOutput{},
301301
{zero_start_index_M[group_index], N, 1});
302302
}
303303
}
@@ -567,6 +567,16 @@ at::Tensor dispatch_fp8_grouped_kernel(
567567
1,
568568
1,
569569
true>(XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
570+
} else if (kernel == KernelMode::Large) {
571+
return f8f8bf16_rowwise_grouped_impl<
572+
InputType,
573+
128,
574+
256,
575+
128,
576+
2,
577+
1,
578+
1,
579+
false>(XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
570580
} else {
571581
return f8f8bf16_rowwise_grouped_impl<
572582
InputType,

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_tensorwise.cu

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,22 @@ at::Tensor f8f8bf16_tensorwise_impl(
9999
KernelScheduleAuto; // Kernel to launch based on the default setting in
100100
// the Collective Builder
101101

102-
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
102+
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
103103
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
104104
using FastDefaultSchedule =
105-
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
105+
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
106106
using FastPongSchedule =
107107
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
108108
using SlowAccum = cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
109109
using FastAccum =
110110
cute::conditional_t<PONG, FastPongSchedule, FastDefaultSchedule>;
111+
using CooperativeEpilogueSchedule =
112+
cutlass::epilogue::TmaWarpSpecializedCooperative;
113+
using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
111114
using MainLoopSchedule =
112115
cute::conditional_t<FAST_ACCUM, FastAccum, SlowAccum>;
116+
using EpilogueSchedule = cute::
117+
conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
113118

114119
using Scale_ =
115120
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementComputeEpilogue>;
@@ -140,7 +145,7 @@ at::Tensor f8f8bf16_tensorwise_impl(
140145
ElementOutput,
141146
LayoutOutput,
142147
AlignmentOutput,
143-
cutlass::epilogue::TmaWarpSpecialized,
148+
EpilogueSchedule,
144149
EpilogueEVT>::CollectiveOp;
145150

146151
using CollectiveMainloop =
@@ -239,10 +244,10 @@ at::Tensor f8f8bf16_tensorwise(
239244
return f8f8bf16_tensorwise_impl<64, 128, 128, 2, 1, 1, true, true>(
240245
XQ, WQ, scale);
241246
} else if (kernel == KernelMode::Large) {
242-
return f8f8bf16_tensorwise_impl<128, 128, 128, 2, 1, 1, true, true>(
247+
return f8f8bf16_tensorwise_impl<128, 256, 128, 2, 1, 1, false, true>(
243248
XQ, WQ, scale);
244249
} else {
245-
return f8f8bf16_tensorwise_impl<128, 128, 128, 1, 2, 1, false, true>(
250+
return f8f8bf16_tensorwise_impl<128, 128, 128, 1, 2, 1, true, true>(
246251
XQ, WQ, scale);
247252
}
248253
}

0 commit comments

Comments
 (0)