
Add mHC fused kernels and tests #38

Merged
hannahli-nv merged 3 commits into NVIDIA:main from Edward-lyz:main
Jan 26, 2026

Conversation

Contributor

Edward-lyz commented Jan 21, 2026

Introduces mHC fused GEMM+RMS+scale/bias/sigmoid, residual application, and Sinkhorn normalization kernels.

Description

Adds a heavily optimized fused forward-inference operator for DeepSeek's mHC algorithm, covering the complete computation path. It significantly outperforms the pure PyTorch implementation, and the GEMM+RMS split-K kernel matches the performance of the official DeepGemm implementation. Implementation details are described in this [blog post](https://edward-lyz.github.io/mhc-%E7%AE%97%E6%B3%95%E5%88%86%E6%9E%90--%E7%AE%97%E5%AD%90%E5%AE%9E%E7%8E%B0/).
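For context, the semantics of the fused GEMM+RMS+scale/bias/sigmoid path can be sketched unfused in NumPy. The shapes, op order, and parameter names below are my assumptions for illustration, not the PR's kernel code:

```python
import numpy as np

def mhc_gemm_rms_scale_ref(x, w, gamma, scale, bias, eps=1e-6):
    """Unfused reference: GEMM -> RMSNorm (last dim) -> scale/bias -> sigmoid.

    The fused kernel computes this in a single pass over memory; this
    sketch only pins down the assumed math to compare against.
    """
    y = x @ w  # GEMM: (M, K) @ (K, N) -> (M, N)
    rms = np.sqrt(np.mean(y * y, axis=-1, keepdims=True) + eps)
    y = (y / rms) * gamma                              # RMSNorm with gain
    return 1.0 / (1.0 + np.exp(-(y * scale + bias)))   # affine + sigmoid

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    M, K, N = 128, 1024, 8  # mirrors the test parametrization below
    out = mhc_gemm_rms_scale_ref(
        rng.standard_normal((M, K)), rng.standard_normal((K, N)),
        np.ones(N), np.ones(N), np.zeros(N))
    print(out.shape)
```

Fusing these stages avoids materializing the intermediate GEMM output in global memory, which is why the bandwidth numbers below are so far ahead of the eager PyTorch baseline.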
Below are the output results for functionality and performance:

pytest -s ops -v -k test_op_mhc                                                                                                                                                    
 ==================================================================================== test session starts ===================================================================================
platform linux -- Python 3.12.12, pytest-8.4.2, pluggy-1.6.0 -- /usr/bin/python3
cachedir: .pytest_cache
rootdir: /home/users/liyanzhen01/PUBLIC_REPO/TileGym
configfile: pytest.ini
plugins: typeguard-4.4.4, anyio-4.12.0
collecting ... Namespace(seed=0, quiet=False, params=None, dump=None, dump_dir='/tmp/', load=None, load_dir='/tmp/', load_names=['cutile'], warmup=100, rep=50, min_rep=2, initial_rep=5, mode='auto', csv=False, file='out', config='dev', fields=['median', 'rel_std'], print_matching=False, verbose=False, help=False)
seed = 0
torch = 2.9.1+cu128
device = CF-NG-BZZ2-O
collected 347 items / 339 deselected / 8 selected

ops/test_mhc.py::Test_MHC::test_op_mhc_gemm_rms_scale_bf16_precision[cutile-128-1024-4] [2026-01-21 06:48:04] [tilegym.backend.selector] [INFO] [logger.py:195] Set backend to cutile
PASSED
ops/test_mhc.py::Test_MHC::test_op_mhc_gemm_rms_scale_bf16_precision[cutile-128-1024-8] [2026-01-21 06:48:26] [tilegym.backend.selector] [INFO] [logger.py:195] Set backend to cutile
PASSED
ops/test_mhc.py::Test_MHC::test_op_mhc_gemm_rms_scale_bf16_precision[cutile-128-1000-8] [2026-01-21 06:48:46] [tilegym.backend.selector] [INFO] [logger.py:195] Set backend to cutile
PASSED
ops/test_mhc.py::Test_MHC::test_op_mhc_sinkhorn[cutile-256-4-torch.float32] [2026-01-21 06:49:07] [tilegym.backend.selector] [INFO] [logger.py:195] Set backend to cutile
PASSED
ops/test_mhc.py::Test_MHC::test_op_mhc_sinkhorn[cutile-256-8-torch.float32] [2026-01-21 06:49:07] [tilegym.backend.selector] [INFO] [logger.py:195] Set backend to cutile
PASSED
ops/test_mhc.py::Test_MHC::test_op_mhc_apply_residual[cutile-128-4-1024] [2026-01-21 06:49:07] [tilegym.backend.selector] [INFO] [logger.py:195] Set backend to cutile
PASSED
ops/test_mhc.py::Test_MHC::test_op_mhc_apply_residual[cutile-64-8-2048] [2026-01-21 06:49:07] [tilegym.backend.selector] [INFO] [logger.py:195] Set backend to cutile
PASSED
ops/test_mhc.py::Test_MHC::test_op_mhc_apply_residual[cutile-128-2-1024] [2026-01-21 06:49:08] [tilegym.backend.selector] [INFO] [logger.py:195] Set backend to cutile
PASSED

===================================================================================== warnings summary ======================================================================================
../../../../../../usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63
  /usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
    import pynvml  # type: ignore[import]

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================== 8 passed, 339 deselected, 1 warning in 66.75s (0:01:06) ==================================================================
mhc-split-gemm-rms-performance-bfloat16-GBps:
        M      CuTile     PyTorch     DeepGemm
0  8192.0  5061.57394  313.423308  5064.876475
mhc-gemm-rms-scale-performance-bfloat16-GBps:
        M       CuTile     PyTorch
0  8192.0  4513.131525  388.723762
mhc-sinkhorn-performance-bfloat16-GBps:
        M     CuTile   PyTorch
0  8192.0  19.152307  2.584275
mhc-apply-residual-performance-bfloat16-GBps:
        M       CuTile     PyTorch
0  8192.0  6602.945269  337.325322
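For reference, Sinkhorn normalization alternately rescales the rows and columns of a positive matrix so it converges toward a doubly stochastic one; a hyper-connection-style design can then use that matrix to mix parallel residual streams. A minimal NumPy sketch, where the iteration count, shapes, and the stream-mixing usage are my assumptions for illustration (not the kernel's actual implementation):

```python
import numpy as np

def sinkhorn_ref(logits, n_iters=20, eps=1e-9):
    """Alternately normalize rows and columns of exp(logits) so the
    result approaches a doubly stochastic matrix."""
    p = np.exp(logits - logits.max())  # positive and numerically stable
    for _ in range(n_iters):
        p = p / (p.sum(axis=-1, keepdims=True) + eps)  # rows sum to ~1
        p = p / (p.sum(axis=-2, keepdims=True) + eps)  # cols sum to ~1
    return p

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    mix = sinkhorn_ref(rng.standard_normal((4, 4)))
    # Hypothetical residual application: mix 4 parallel residual
    # streams of width 1024 with the normalized matrix.
    streams = rng.standard_normal((4, 1024))
    mixed = mix @ streams
    print(mixed.shape)
```

Because each iteration is just two small reductions, the op is heavily bandwidth- and launch-overhead-bound, which is consistent with the ~7x GB/s gap over PyTorch in the sinkhorn benchmark above.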

CI Configuration

config:
  build: true
  # valid options are "ops" and "benchmark"
  test: ["ops","benchmark"]

Checklist

  • Code formatted and imports sorted via repo specifications (./format.sh)
  • Documentation updated (if needed)
  • CI configuration reviewed


copy-pr-bot bot commented Jan 21, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Edward-lyz
Contributor Author

@tabo @hartsock @aaronp24 @aflat
Apologies for the interruption, but this PR needs additional permissions before CI can run to verify functionality and performance. (I have submitted the CLA form.)

@hannahli-nv
Collaborator

> @tabo @hartsock @aaronp24 @aflat I apologize for interrupting, but this PR requires additional permissions to verify functionality/performance after executing CI (I have submitted the CLA application form)

Hi @Edward-lyz, thank you for your contribution! We have received your CLA form. The accounts you mentioned are not members of this repo. I will start a pipeline for this PR later.

@hannahli-nv
Collaborator

/ok to test 9ea8a5d

@Edward-lyz
Contributor Author

Edward-lyz commented Jan 22, 2026

> /ok to test 9ea8a5d

Sorry to bother you: I noticed that one CI job failed due to network issues. Is it possible to rerun just that job instead of the entire pipeline? @hannahli-nv

@Edward-lyz
Contributor Author

Current runner version: '2.330.0'
Runner name: 'cfc3-l-amd-g-rtxpro6000-l-1-s8ld7-runner-cj5n2'
Runner group name: 'nv-gpu-amd64-rtxpro6000-1gpu'
Machine name: 'cfc3-l-amd-g-rtxpro6000-l-1-s8ld7-runner-cj5n2'
NVIDIA Managed Runner
GITHUB_TOKEN Permissions
Secret source: Actions
Prepare workflow directory
Prepare all required actions
Getting action download info
Download action repository 'actions/checkout@v4' (SHA:34e114876b0b11c390a56381ad16ebd13914f8d5)
Download action repository 'dawidd6/action-download-artifact@v3' (SHA:09f2f74827fd3a8607589e5ad7f9398816f540fe)
Warning: Failed to download action 'https://api.github.com/repos/dawidd6/action-download-artifact/tarball/09f2f74827fd3a8607589e5ad7f9398816f540fe'. Error: The request was canceled due to the configured HttpClient.Timeout of 100 seconds elapsing. 
Warning: Back off 27.617 seconds before retry.
Warning: Failed to download action 'https://api.github.com/repos/dawidd6/action-download-artifact/tarball/09f2f74827fd3a8607589e5ad7f9398816f540fe'. Error: The request was canceled due to the configured HttpClient.Timeout of 100 seconds elapsing. 
Warning: Back off 18.652 seconds before retry.
Error: Action 'https://api.github.com/repos/dawidd6/action-download-artifact/tarball/09f2f74827fd3a8607589e5ad7f9398816f540fe' download has timed out. Error: The request was canceled due to the configured HttpClient.Timeout of 100 seconds elapsing. 

This error has appeared again on the second run. Should I remove the benchmark tests, or could you please take another look at why it keeps occurring? Thank you. @hannahli-nv

@Edward-lyz
Contributor Author

I see that all CI checks have passed; can the code be merged? Please take a look. @hannahli-nv

@hannahli-nv
Collaborator

> I observed that all CI tests have passed, can the code be merged? Please take a look here. @hannahli-nv

Thank you for the reminder. We will take a look at it later today.



@dispatch(
"mhc_gemm_rms_scale",
Collaborator

hannahli-nv commented Jan 26, 2026


Thank you for your nice work! Overall LGTM.
Would you mind moving these mHC-related functions to the "NN Operations" category instead of "Linear Algebra Operations", to keep it consistent with the classification in `__init__.py`?

@hannahli-nv
Collaborator

/ok to test add84e3

hannahli-nv merged commit 680497a into NVIDIA:main on Jan 26, 2026
8 checks passed
