Add mHC fused kernels and tests #38
Conversation
Introduces mHC fused GEMM+RMS+scale/bias/sigmoid, residual application, and Sinkhorn normalization kernels.
Hi @Edward-lyz, thank you for your contribution! We have received your CLA file. The accounts you mentioned are not members of this repo. I will start a pipeline for this PR later.
/ok to test 9ea8a5d
Sorry to bother you. I noticed that a test job failed due to network issues. Is it possible to re-run just that job instead of retesting everything? @hannahli-nv
```
Current runner version: '2.330.0'
Runner name: 'cfc3-l-amd-g-rtxpro6000-l-1-s8ld7-runner-cj5n2'
Runner group name: 'nv-gpu-amd64-rtxpro6000-1gpu'
Machine name: 'cfc3-l-amd-g-rtxpro6000-l-1-s8ld7-runner-cj5n2'
NVIDIA Managed Runner
GITHUB_TOKEN Permissions
Secret source: Actions
Prepare workflow directory
Prepare all required actions
Getting action download info
Download action repository 'actions/checkout@v4' (SHA:34e114876b0b11c390a56381ad16ebd13914f8d5)
Download action repository 'dawidd6/action-download-artifact@v3' (SHA:09f2f74827fd3a8607589e5ad7f9398816f540fe)
Warning: Failed to download action 'https://api.github.com/repos/dawidd6/action-download-artifact/tarball/09f2f74827fd3a8607589e5ad7f9398816f540fe'. Error: The request was canceled due to the configured HttpClient.Timeout of 100 seconds elapsing.
Warning: Back off 27.617 seconds before retry.
Warning: Failed to download action 'https://api.github.com/repos/dawidd6/action-download-artifact/tarball/09f2f74827fd3a8607589e5ad7f9398816f540fe'. Error: The request was canceled due to the configured HttpClient.Timeout of 100 seconds elapsing.
Warning: Back off 18.652 seconds before retry.
Error: Action 'https://api.github.com/repos/dawidd6/action-download-artifact/tarball/09f2f74827fd3a8607589e5ad7f9398816f540fe' download has timed out. Error: The request was canceled due to the configured HttpClient.Timeout of 100 seconds elapsing.
```

This error has appeared again on the second run. Maybe I should remove the benchmark tests? Could you please take another look at why this error is occurring? Thank you. @hannahli-nv
I see that all CI tests have passed; can the code be merged? Please take a look. @hannahli-nv
Thank you for the reminder. We will take a look at it later today.
```python
@dispatch(
    "mhc_gemm_rms_scale",
```
Thank you for your nice work! Overall LGTM.
Would you mind moving these mHC-related functions to the "NN Operations" category instead of "Linear Algebra Operations", to keep the classification consistent with `__init__.py`?
/ok to test add84e3
Description
Adds a heavily optimized fused forward-inference operator for DeepSeek's mHC algorithm, covering the complete computational process. Its performance far exceeds the PyTorch implementation, and the split-K GEMM+RMS kernel matches the performance of the official DeepGemm implementation. A detailed write-up is available at this [link](https://edward-lyz.github.io/mhc-%E7%AE%97%E6%B3%95%E5%88%86%E6%9E%90--%E7%AE%97%E5%AD%90%E5%AE%9E%E7%8E%B0/).
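For reviewers unfamiliar with the op, here is a minimal NumPy sketch of what a fused GEMM + RMSNorm + scale/bias/sigmoid kernel (plus the residual application) computes. The epsilon placement, normalization axis, and the gated residual formula are illustrative assumptions, not taken from the kernel source.

```python
import numpy as np

def gemm_rms_scale_ref(x, w, scale, bias, eps=1e-6):
    """Reference for a fused GEMM + RMSNorm + scale/bias/sigmoid op.

    x: (M, K) activations, w: (K, N) weights,
    scale/bias: (N,) per-channel affine parameters.
    Axis and eps placement are assumptions for illustration.
    """
    y = x @ w                                           # GEMM
    rms = np.sqrt(np.mean(y * y, axis=-1, keepdims=True) + eps)
    y = y / rms                                         # RMS normalization
    return 1.0 / (1.0 + np.exp(-(y * scale + bias)))    # scale/bias + sigmoid

def apply_residual_ref(h, res, gate):
    """Hypothetical residual application: gated add of the residual stream."""
    return h + gate * res

M, K, N = 4, 8, 5
rng = np.random.default_rng(0)
x, w = rng.standard_normal((M, K)), rng.standard_normal((K, N))
scale, bias = np.ones(N), np.zeros(N)
out = gemm_rms_scale_ref(x, w, scale, bias)
print(out.shape)  # sigmoid output lies strictly in (0, 1)
```

A fused kernel computes all of these stages in one pass over the GEMM output tiles, which is where the bandwidth savings over the unfused PyTorch baseline come from.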
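The Sinkhorn normalization kernel alternates row and column normalizations until the matrix is approximately doubly stochastic. A short NumPy sketch of the standard iteration follows; the iteration count and stopping rule here are illustrative and may differ from the kernel's.

```python
import numpy as np

def sinkhorn_ref(logits, n_iters=50):
    """Standard Sinkhorn iteration: alternately normalize the rows and
    columns of exp(logits) until it is roughly doubly stochastic.
    The fixed iteration count is illustrative."""
    p = np.exp(logits - logits.max())       # stabilize before exponentiating
    for _ in range(n_iters):
        p /= p.sum(axis=1, keepdims=True)   # rows sum to 1
        p /= p.sum(axis=0, keepdims=True)   # columns sum to 1
    return p

p = sinkhorn_ref(np.random.default_rng(1).standard_normal((6, 6)))
print(p.sum(axis=0), p.sum(axis=1))  # both close to all-ones
```

Because each step is a cheap elementwise divide after a reduction, this op is memory-bound, which matches the low GB/s numbers reported for it below relative to the GEMM kernels.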
Below are the functional correctness and performance results:
mhc-split-gemm-rms-performance-bfloat16-GBps:

| M | CuTile | PyTorch | DeepGemm |
| --- | --- | --- | --- |
| 8192.0 | 5061.57394 | 313.423308 | 5064.876475 |

mhc-gemm-rms-scale-performance-bfloat16-GBps:

| M | CuTile | PyTorch |
| --- | --- | --- |
| 8192.0 | 4513.131525 | 388.723762 |

mhc-sinkhorn-performance-bfloat16-GBps:

| M | CuTile | PyTorch |
| --- | --- | --- |
| 8192.0 | 19.152307 | 2.584275 |

mhc-apply-residual-performance-bfloat16-GBps:

| M | CuTile | PyTorch |
| --- | --- | --- |
| 8192.0 | 6602.945269 | 337.325322 |

CI Configuration
Checklist
./format.sh)