Performance Optimization: Optimized TileShape Configuration for bf16 and Mixed Formats (#3710)
Summary:
Pull Request resolved: #3710
X-link: facebookresearch/FBGEMM#783
## Performance Issue with the Current bf16 and Mixed-Format TileShape Configuration
The current FBGEMM bf16 kernels use a TileShape configuration of 128x128x128,
while the optimal shape for the dense bf16 tensor cores on H100 is m64n256k16.
The current configuration leads to suboptimal tensor core and bandwidth utilization,
as evidenced by PTX warnings about
'wgmma.mma_async instruction serialization due to insufficient register resources'.
## Optimized TileShape (128x256x64) Implementation
This PR changes the TileShape configuration from 128x128x128 to 128x256x64 for large GEMM
operations using a cooperative kernel, enabling better bandwidth and tensor core utilization.
This configuration is notably used in Flash Attention V3 and was identified by Colfax-intl
as the optimal configuration for bf16 kernels in an empirical study.
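For context, below is a minimal CUTLASS 3.x-style sketch of how a Hopper bf16 GEMM with a
128x256x64 tile and a TMA warp-specialized cooperative schedule is typically expressed. This is
an illustration of the configuration rather than the FBGEMM source; the cluster shape, layouts,
and alignments are assumptions.

```cpp
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cute/tensor.hpp>

using ElementA = cutlass::bfloat16_t;
using ElementB = cutlass::bfloat16_t;
using ElementC = cutlass::bfloat16_t;
using ElementAccumulator = float;

// The tile discussed in this PR: 128x256x64 (M x N x K).
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_64>;
// Cluster shape is an assumption for this sketch.
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;

// 16-byte alignment for bf16 operands (128 bits / 16 bits per element).
constexpr int Alignment = 128 / cutlass::sizeof_bits<ElementA>::value;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, cutlass::layout::RowMajor, Alignment,
    ElementC, cutlass::layout::RowMajor, Alignment,
    cutlass::epilogue::TmaWarpSpecializedCooperative>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementA, cutlass::layout::RowMajor, Alignment,
    ElementB, cutlass::layout::ColumnMajor, Alignment,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    // The cooperative (persistent, warp-specialized) kernel schedule.
    cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,  // problem shape (M, N, K, L)
    CollectiveMainloop,
    CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
```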
## Benchmark Results on H100 GPU
### Benchmark configuration
- PyTorch 2.6
- CUDA 12.4
- CPU: AMD EPYC
- GPU: NVIDIA H100
Benchmarks are configured with 30 kernel launch iterations
and averaged over 25 benchmark runs.
We used the same GEMM sizes as in the Colfax benchmarks.
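For reference, the reported TFlops follow directly from the problem sizes and the averaged
kernel time. A minimal sketch of that arithmetic is shown below; the 1.39 ms timing is an
illustrative value chosen to match the grouped bf16 result, not a measured number.

```cpp
#include <cstdint>
#include <cstdio>

// Effective TFlops for a (grouped) GEMM given an averaged kernel time in
// milliseconds. A dense GEMM performs 2*M*N*K FLOPs; a grouped or batched
// GEMM with G equal-sized problems performs G times that.
double effective_tflops(int64_t G, int64_t M, int64_t N, int64_t K, double avg_ms) {
  const double flops = 2.0 * static_cast<double>(G) * M * N * K;
  return flops / (avg_ms * 1e-3) / 1e12;
}

int main() {
  // Example: G = 4, M = 2048, N = K = 8192. An averaged time of ~1.39 ms
  // (illustrative) corresponds to roughly 790 TFlops, i.e. the 128-256-64
  // grouped bf16 result reported below.
  std::printf("%.0f TFlops\n", effective_tflops(4, 2048, 8192, 8192, 1.39));
  return 0;
}
```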
### Benchmark
#### bf16bf16bf16_grouped (G = 4, M = 2,048, N = 8,192, K = 8,192)
| TileShape   | TFlops |
|-------------|--------|
| 128-128-128 | 606    |
| 128-256-64  | 790    |
#### bf16i4bf16_rowwise_batched (B = 4, M = 2,048, N = 8,192, K = 8,192)
| TileShape   | TFlops bf16* | TFlops fp16* | TFlops float* |
|-------------|--------------|--------------|---------------|
| 128-128-128 | 354          | 341          | 383           |
| 128-256-64  | 704          | 727          | 763           |
#### bf16i4bf16_rowwise (M=N=K = 8,192)
| TileShape   | TFlops bf16* | TFlops fp16* | TFlops float* |
|-------------|--------------|--------------|---------------|
| 128-128-128 | 349          | 351          | 381           |
| 128-256-64  | 652          | 663          | 693           |
#### f8i4bf16_rowwise (M=N=K = 8,192)
| TileShape   | TFlops bf16* | TFlops fp16* | TFlops float* |
|-------------|--------------|--------------|---------------|
| 128-128-128 | 407          | 542          | 606           |
| 128-256-64  | 921          | 942          | 1088          |
*WEIGHT_SCALE_DTYPE
## Technical Implementation
Modified TileShape from 128-128-128 to 128-256-64 for:
- bf16bf16bf16_grouped
- bf16i4bf16_rowwise_batched
- bf16i4bf16_rowwise
- f8i4bf16_rowwise
Enabled the cooperative kernel schedule by default for:
- bf16i4bf16_rowwise_batched
- bf16i4bf16_rowwise
- f8i4bf16_rowwise
The modifications only affect the large-mode and default kernels where N > 128
(see the dispatch sketch below). These changes were made by modifying the minimum
necessary code while respecting existing coding practices in FBGEMM.
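The N > 128 heuristic can be illustrated with a small, hypothetical dispatch sketch. The
function name, the `TileConfig` struct, and the small-N tile values are illustrative and are
not FBGEMM's actual dispatch code.

```cpp
#include <cstdint>

// Tile configuration selected at dispatch time. The cooperative flag stands
// for the TMA warp-specialized cooperative kernel schedule.
struct TileConfig {
  int tile_m;
  int tile_n;
  int tile_k;
  bool cooperative;
};

// Hypothetical selection heuristic: the 128x256x64 cooperative configuration
// is only chosen on the default/large path where N > 128; small-N problems
// keep their previous tiles (the 64x128x128 values here are illustrative).
TileConfig select_tile_config(int64_t N) {
  if (N > 128) {
    return {128, 256, 64, /*cooperative=*/true};
  }
  return {64, 128, 128, /*cooperative=*/false};
}
```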
## Test Coverage
### Unit Tests Results
The unit tests in fbgemm_gpu/experimental/gen_ai/test/quantize
have been verified for the modified kernels.
@jiawenliu64 @jwfromm Hello! I wanted to share this contribution to FBGEMM.
While this is my first PR, I hope these changes could be useful for this great project.
I'd welcome any feedback if you have time to take a look. Thank you!
Pull Request resolved: #3591
Reviewed By: jianyuh
Differential Revision: D68609243
Pulled By: jiawenliu64
fbshipit-source-id: e6cc2a9e42f2fc7da76f5fa7189fe773a8c69e51