Commit ec50d16

Performance Optimization: Optimized TileShape Configuration for bf16 and Mixed Formats (#3710)
Summary:
Pull Request resolved: #3710
X-link: facebookresearch/FBGEMM#783

## Performance Issue with the Current BF16 and Mixed-Format TileShape Configuration

The current FBGEMM bf16 kernel uses a TileShape configuration of 128x128x128, while the optimal shape for the dense bf16 tensor core on H100 is m64n256k16. The current configuration leads to suboptimal tensor core and bandwidth utilization, as evidenced by PTX warnings about 'wgmma.mma_async instruction serialization due to insufficient register resources'.

## Optimized TileShape (128x256x64) Implementation

Change the TileShape configuration from 128x128x128 to 128x256x64 for large GEMM operations and use a cooperative kernel, enabling better bandwidth and tensor core utilization. This configuration is notably used in Flash Attention V3 and was identified by Colfax-intl as the optimal configuration for bf16 kernels in their empirical study.

## Benchmark Results on H100 GPU

### Benchmark configuration

- PyTorch 2.6, CUDA 12.4
- CPU: AMD EPYC
- GPU: NVIDIA H100

Benchmarks use 30 kernel launch iterations and results are averaged over 25 benchmark runs. We use the same GEMM sizes as in the Colfax benchmarks.

### Benchmarks

#### bf16bf16bf16_grouped (G = 4, M = 2,048, N = 8,192, K = 8,192)

| TileShape   | TFlops |
|-------------|--------|
| 128-128-128 | 606    |
| 128-256-64  | 790    |

#### bf16i4bf16_rowwise_batched (B = 4, M = 2,048, N = 8,192, K = 8,192)

| TileShape   | TFlops bf16* | TFlops fp16* | TFlops float* |
|-------------|--------------|--------------|---------------|
| 128-128-128 | 354          | 341          | 383           |
| 128-256-64  | 704          | 727          | 763           |

#### bf16i4bf16_rowwise (M = N = K = 8,192)

| TileShape   | TFlops bf16* | TFlops fp16* | TFlops float* |
|-------------|--------------|--------------|---------------|
| 128-128-128 | 349          | 351          | 381           |
| 128-256-64  | 652          | 663          | 693           |

#### f8i4bf16_rowwise (M = N = K = 8,192)

| TileShape   | TFlops bf16* | TFlops fp16* | TFlops float* |
|-------------|--------------|--------------|---------------|
| 128-128-128 | 407          | 542          | 606           |
| 128-256-64  | 921          | 942          | 1088          |

*WEIGHT_SCALE_DTYPE

## Technical Implementation

Modified TileShape from 128-128-128 to 128-256-64 for:
- bf16bf16bf16_grouped
- bf16i4bf16_rowwise_batched
- bf16i4bf16_rowwise
- f8i4bf16_rowwise

Added a cooperative kernel by default for:
- bf16i4bf16_rowwise_batched
- bf16i4bf16_rowwise
- f8i4bf16_rowwise

The modifications only affect the Large mode and Default kernels where N > 128. These changes were made by modifying the minimum necessary code while respecting existing coding practices in FBGEMM.

## Test Coverage

### Unit Test Results

The unit tests in fbgemm_gpu/experimental/gen_ai/test/quantize have been verified for the modified kernels.

jiawenliu64 jwfromm Hello! I wanted to share this contribution to FBGEMM. While this is my first PR, I hope these changes could be useful for this great project. I'd welcome any feedback if you have time to take a look. Thank you!

Pull Request resolved: #3591
Reviewed By: jianyuh
Differential Revision: D68609243
Pulled By: jiawenliu64
1 parent 853e97c commit ec50d16
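Before the file-by-file diffs, here is a minimal, self-contained C++ sketch of the schedule-selection pattern the mixed-input kernels below adopt. This is not the FBGEMM code: the CUTLASS schedule names are re-declared locally as empty stand-in tags so the snippet compiles on its own. It only models how the `PONG` template flag now picks a pingpong or a cooperative mainloop schedule together with a matching epilogue schedule, instead of always falling back to the TMA warp-specialized default.

```cpp
// Minimal sketch (stand-in types, not CUTLASS): PONG selects a matching
// mainloop/epilogue schedule pair, mirroring the cute::conditional_t pattern
// used in the diffs below.
#include <cstdio>
#include <type_traits>

// Local stand-ins for the CUTLASS schedule tags used in the real kernels.
struct KernelTmaWarpSpecializedCooperativeMixedInput {};
struct KernelTmaWarpSpecializedPingpongMixedInput {};
struct TmaWarpSpecializedCooperative {};
struct TmaWarpSpecialized {};

template <bool PONG>
struct ScheduleSelector {
  // Pingpong kernels keep the plain TMA warp-specialized epilogue; cooperative
  // kernels pair with the cooperative epilogue so mainloop and epilogue agree.
  using MainLoopSchedule = std::conditional_t<
      PONG,
      KernelTmaWarpSpecializedPingpongMixedInput,
      KernelTmaWarpSpecializedCooperativeMixedInput>;
  using EpilogueSchedule = std::conditional_t<
      PONG,
      TmaWarpSpecialized,
      TmaWarpSpecializedCooperative>;
};

int main() {
  // The Large/Default mixed-input dispatches below now pass PONG = false,
  // so they get the cooperative pair together with the 128x256x64 tile.
  using LargeKernel = ScheduleSelector</*PONG=*/false>;
  static_assert(std::is_same_v<LargeKernel::MainLoopSchedule,
                               KernelTmaWarpSpecializedCooperativeMixedInput>);
  std::puts("PONG=false selects the cooperative mainloop/epilogue pair");
  return 0;
}
```

In the real kernels this selection is written with `cute::conditional_t` over the actual `cutlass::gemm` and `cutlass::epilogue` schedule types, as shown in the diffs that follow.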

File tree

4 files changed: +57 additions, −42 deletions

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 20 additions & 20 deletions
```diff
@@ -138,20 +138,20 @@ __global__ void set_dynamic_kernel_args_kernel(
 GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape*>(
 problem_shape_buf);
 // Pass dummy configs to get Stride structure
-GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideInputA* stride_input_A_ptr = reinterpret_cast<
 GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideInputA*>(stride_buf);
-GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideInputB* stride_input_B_ptr = reinterpret_cast<
 GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideInputB*>(stride_buf + stride_size);
-GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideOutput* stride_output_ptr = reinterpret_cast<
 GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideOutput*>(stride_buf + (stride_size * 2));

 output_args_ptr[group_index] =
@@ -167,15 +167,15 @@ __global__ void set_dynamic_kernel_args_kernel(
 zero_start_index_M[group_index], N, K);
 stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
 typename GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
 {zero_start_index_M[group_index], K, 1});
 stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
 typename GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
 {N, K, 1});
 stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
 typename GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
 {zero_start_index_M[group_index], N, 1});
 }
 }
@@ -212,20 +212,20 @@ __global__ void set_static_kernel_args_kernel(
 GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape*>(
 problem_shape_buf);
 // Pass dummy configs to get Stride structure
-GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideInputA* stride_input_A_ptr = reinterpret_cast<
 GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideInputA*>(stride_buf);
-GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideInputB* stride_input_B_ptr = reinterpret_cast<
 GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideInputB*>(stride_buf + stride_size);
-GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideOutput* stride_output_ptr = reinterpret_cast<
 GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
 StrideOutput*>(stride_buf + (stride_size * 2));

 output_args_ptr[group_index] = reinterpret_cast<int64_t>(output_data);
@@ -237,15 +237,15 @@ __global__ void set_static_kernel_args_kernel(
 GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape(M, N, K);
 stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
 typename GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
 {M, K, 1});
 stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
 typename GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
 {N, K, 1});
 stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
 typename GroupedGemmBF16Args::
-GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
+GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
 {M, N, 1});
 }
 }
@@ -470,10 +470,10 @@ std::vector<at::Tensor> dispatch_bf16_grouped_kernel(
 return bf16bf16bf16_grouped_impl<64, 128, 128, 2, 1, 1, true>(
 x_group, w_group, output_tensor, zero_start_index_M);
 } else if (kernel == KernelMode::Large) {
-return bf16bf16bf16_grouped_impl<128, 128, 128, 2, 1, 1, true>(
+return bf16bf16bf16_grouped_impl<128, 256, 64, 2, 1, 1, false>(
 x_group, w_group, output_tensor, zero_start_index_M);
 } else {
-return bf16bf16bf16_grouped_impl<128, 128, 128, 1, 2, 1, true>(
+return bf16bf16bf16_grouped_impl<128, 256, 64, 2, 1, 1, false>(
 x_group, w_group, output_tensor, zero_start_index_M);
 }
 }
```
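As a sanity check on the grouped-GEMM figures in the commit message, here is a back-of-envelope sketch. It is illustrative arithmetic only, using the published shapes and TFlops numbers; it does not time or launch anything.

```cpp
// Implied work and runtime for bf16bf16bf16_grouped with
// G = 4, M = 2,048, N = 8,192, K = 8,192, at the reported throughputs.
#include <cstdio>

int main() {
  const double G = 4, M = 2048, N = 8192, K = 8192;
  const double flops = 2.0 * G * M * N * K;         // ~1.10e12 FLOPs
  const double tflops_old = 606.0, tflops_new = 790.0;
  std::printf("total work   : %.3g FLOPs\n", flops);
  std::printf("128-128-128  : %.2f ms\n", flops / (tflops_old * 1e12) * 1e3); // ~1.81 ms
  std::printf("128-256-64   : %.2f ms\n", flops / (tflops_new * 1e12) * 1e3); // ~1.39 ms
  std::printf("speedup      : %.2fx\n", tflops_new / tflops_old);             // ~1.30x
  return 0;
}
```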

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu

Lines changed: 13 additions & 8 deletions
```diff
@@ -98,13 +98,18 @@ at::Tensor bf16i4bf16_rowwise_impl(
 cute::Int<TBS_K>>; // Shape of the
 // threadblocks in a
 // cluster
-using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
+using CooperativeSchedule =
+    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
 using PongSchedule =
 cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
-using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+using CooperativeEpilogueSchedule =
+    cutlass::epilogue::TmaWarpSpecializedCooperative;
+using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
 using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
 using MainLoopSchedule =
-cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
+using EpilogueSchedule = cute::
+    conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;

 using CollectiveEpilogue =
 typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -231,18 +236,18 @@ at::Tensor dispatch_bf16i4bf16_rowwise_kernel(
 } else if (kernel == KernelMode::Large) {
 return bf16i4bf16_rowwise_impl<
 128,
-128,
-128,
+256,
+64,
 2,
 1,
 1,
-true,
+false,
 WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
 } else {
 return bf16i4bf16_rowwise_impl<
 128,
-128,
-128,
+256,
+64,
 2,
 1,
 1,
```
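One way to read the 128x256x64 plus cooperative combination, stated as my interpretation rather than anything this diff asserts: with the cooperative schedule the 128-row CTA tile is split along M across two consumer warpgroups, so each warpgroup issues m64n256k16 wgmma operations, which is the dense bf16 tensor-core shape the commit message cites as optimal on H100. The exact mapping is decided inside CUTLASS; the tiny standalone sketch below only spells out that arithmetic.

```cpp
// Illustrative arithmetic only (not FBGEMM/CUTLASS code): per-warpgroup wgmma
// shape implied by a 128x256x64 CTA tile under a two-consumer-warpgroup
// cooperative schedule (the warpgroup count and split are assumptions).
#include <cstdio>

int main() {
  constexpr int tile_m = 128, tile_n = 256, tile_k = 64;
  constexpr int consumer_warpgroups = 2;                 // assumed cooperative split
  constexpr int wgmma_m = tile_m / consumer_warpgroups;  // 64 rows per warpgroup
  constexpr int wgmma_k = 16;                            // bf16 wgmma K depth
  constexpr int k_steps = tile_k / wgmma_k;              // 4 wgmma steps per K tile
  std::printf("per-warpgroup wgmma shape: m%dn%dk%d, %d k-steps per tile\n",
              wgmma_m, tile_n, wgmma_k, k_steps);
  return 0;
}
```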

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu

Lines changed: 11 additions & 6 deletions
```diff
@@ -102,13 +102,18 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
 cute::Int<TBS_K>>; // Shape of the
 // threadblocks in a
 // cluster
-using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
+using CooperativeSchedule =
+    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
 using PongSchedule =
 cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
-using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+using CooperativeEpilogueSchedule =
+    cutlass::epilogue::TmaWarpSpecializedCooperative;
+using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
 using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
 using MainLoopSchedule =
-cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
+using EpilogueSchedule = cute::
+    conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;

 using CollectiveEpilogue =
 typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -235,17 +240,17 @@ at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel(
 } else if (kernel == KernelMode::Large) {
 return bf16i4bf16_rowwise_batched_impl<
 128,
-128,
+256,
 64,
 2,
 1,
 1,
-true,
+false,
 WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
 } else {
 return bf16i4bf16_rowwise_batched_impl<
 128,
-128,
+256,
 64,
 2,
 1,
```

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu

Lines changed: 13 additions & 8 deletions
```diff
@@ -92,13 +92,18 @@ at::Tensor f8i4bf16_rowwise_impl(
 cute::Int<TBS_K>>; // Shape of the
 // threadblocks in a
 // cluster
-using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
+using CooperativeSchedule =
+    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
 using PongSchedule =
 cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
-using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+using CooperativeEpilogueSchedule =
+    cutlass::epilogue::TmaWarpSpecializedCooperative;
+using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
 using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
 using MainLoopSchedule =
-cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
+using EpilogueSchedule = cute::
+    conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;

 // Implement rowwise scaling epilogue for x
 using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
@@ -254,19 +259,19 @@ at::Tensor dispatch_f8i4bf16_rowwise_kernel(
 } else if (kernel == KernelMode::Large) {
 return f8i4bf16_rowwise_impl<
 128,
-128,
-128,
+256,
+64,
 2,
 1,
 1,
-true,
+false,
 InputDType,
 WEIGHT_SCALE_DTYPE>(XQ, WQ, x_scale, w_scale, w_zp);
 } else {
 return f8i4bf16_rowwise_impl<
 128,
-128,
-128,
+256,
+64,
 2,
 1,
 1,
```
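A final illustrative estimate for the 8,192 x 8,192 x 8,192 case benchmarked above. The assumptions here are mine, not the PR's: roughly one resident CTA per SM for these register-heavy kernels, and 132 SMs on an H100 SXM. Under those assumptions, the wider 128x256 output tile halves the number of CTAs needed to cover the output and therefore roughly halves the number of tile waves.

```cpp
// Rough wave-quantization estimate (assumptions: 1 CTA/SM, 132 SMs):
// CTA count over an 8,192 x 8,192 output for the old and new tile shapes.
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int M = 8192, N = 8192, sms = 132;
  const int ctas_old = ceil_div(M, 128) * ceil_div(N, 128);  // 64 * 64 = 4096
  const int ctas_new = ceil_div(M, 128) * ceil_div(N, 256);  // 64 * 32 = 2048
  std::printf("128x128 tiles: %d CTAs (~%.1f waves)\n", ctas_old, double(ctas_old) / sms);
  std::printf("128x256 tiles: %d CTAs (~%.1f waves)\n", ctas_new, double(ctas_new) / sms);
  return 0;
}
```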
