@danielvegamyhre danielvegamyhre commented Dec 29, 2025

Summary

  • Integrate the new CUDA kernel for converting per-group scales to blocked format into the mxfp8 grouped GEMM autograd function
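For context, here is a hypothetical pure-Python sketch (not the PR's CUDA kernel) of the indexing problem the kernel solves: the blocked scale layout expected by the mxfp8 grouped GEMM pads each group's scale rows up to a multiple of 128 (and scale columns up to a multiple of 4), so each group's start offset into the blocked buffer must be computed from padded, not logical, sizes. Function and parameter names below are illustrative.

```python
# Hypothetical sketch of per-group offset computation for a blocked scale
# layout: rows padded to a multiple of 128, columns to a multiple of 4.
# This illustrates the addressing logic only; the PR implements the actual
# conversion as a CUDA kernel.

def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple."""
    return ((x + multiple - 1) // multiple) * multiple

def blocked_group_offsets(group_rows, scale_cols, row_tile=128, col_tile=4):
    """Return each group's starting row offset into the blocked scale
    buffer, the total padded row count, and the padded column count."""
    padded_cols = round_up(scale_cols, col_tile)
    offsets = []
    start = 0
    for rows in group_rows:
        offsets.append(start)
        start += round_up(rows, row_tile)  # next group starts tile-aligned
    return offsets, start, padded_cols

# Example: three groups of 100, 300, and 56 rows with 5 scale columns.
offsets, total_rows, padded_cols = blocked_group_offsets([100, 300, 56], scale_cols=5)
print(offsets, total_rows, padded_cols)  # [0, 128, 512] 640 8
```

Because every group begins at a 128-row-aligned offset, each group's scales can be written (or read) independently, which is what makes a single fused conversion kernel over all groups practical.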

Tests

  • pytest test/prototype/moe_training/test_scaled_grouped_mm.py

Benchmarks for the autograd function forward + backward (dynamic mxfp8 quantization + mxfp8 grouped GEMM)

Before:

M,N,K,G                  recipe                  bf16_fwd_bwd_us    scaled_fwd_bwd_us  scaled_fwd_bwd_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-----------------------  --------------------  -----------------  -------------------  ------------------------  -------------  ---------------  --------------------
(128000, 8192, 5120, 1)  MoEScalingType.MXFP8           32520.3              19518.3   1.666x                         10522.7           5773.22  1.823x
(128000, 8192, 5120, 2)  MoEScalingType.MXFP8           32451.6              19549.1   1.66x                          10116.2           5743.09  1.761x
(128000, 8192, 5120, 4)  MoEScalingType.MXFP8           32233.4              19376.1   1.664x                         11167.8           5711.9   1.955x
(128000, 8192, 5120, 8)  MoEScalingType.MXFP8           31674.5              19295.2   1.642x                         10106.9           5474.24  1.846x
(128000, 1536, 5120, 1)  MoEScalingType.MXFP8            6416.42              6939.65  0.925x                          1834.88          2022.42  0.907x
(128000, 1536, 5120, 2)  MoEScalingType.MXFP8            6320.22              6224.9   1.015x                          1658.82          1814.56  0.914x
(128000, 1536, 5120, 4)  MoEScalingType.MXFP8            5755.81              6218.88  0.926x                          2026.56          1820.61  1.113x
(128000, 1536, 5120, 8)  MoEScalingType.MXFP8            6302.72              5334.99  1.181x                          1810.38          1610.78  1.124x
(128000, 2048, 7168, 1)  MoEScalingType.MXFP8           11666.4               9886.19  1.18x                           3840.4           2917.47  1.316x
(128000, 2048, 7168, 2)  MoEScalingType.MXFP8           11647.9               9666.58  1.205x                          3779.55          2842.21  1.33x
(128000, 2048, 7168, 4)  MoEScalingType.MXFP8           11309.1               8625.25  1.311x                          3816.94          2634.32  1.449x
(128000, 2048, 7168, 8)  MoEScalingType.MXFP8           11490.9               8418.27  1.365x                          3389.47          2543.52  1.333x

After:

M,N,K,G                  recipe                  bf16_fwd_bwd_us    scaled_fwd_bwd_us  scaled_fwd_bwd_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-----------------------  --------------------  -----------------  -------------------  ------------------------  -------------  ---------------  --------------------
(128000, 8192, 5120, 1)  MoEScalingType.MXFP8           32323.8              18576.4   1.74x                          10221.6           5311.42  1.924x
(128000, 8192, 5120, 2)  MoEScalingType.MXFP8           31286.3              18587.6   1.683x                         10188.8           5366.29  1.899x
(128000, 8192, 5120, 4)  MoEScalingType.MXFP8           35184.9              19145.8   1.838x                         10301.5           5503.01  1.872x
(128000, 8192, 5120, 8)  MoEScalingType.MXFP8           33045.6              19265.6   1.715x                         10010.7           5363.28  1.867x
(128000, 1536, 5120, 1)  MoEScalingType.MXFP8            6279.46              5737.41  1.094x                          1532.86          1470.59  1.042x
(128000, 1536, 5120, 2)  MoEScalingType.MXFP8            6260.86              5632     1.112x                          2027.46          1480.7   1.369x
(128000, 1536, 5120, 4)  MoEScalingType.MXFP8            6371.36              5297.22  1.203x                          1972.29          1498.18  1.316x
(128000, 1536, 5120, 8)  MoEScalingType.MXFP8            6428.58              5321.2   1.208x                          1761.17          1463.33  1.204x
(128000, 2048, 7168, 1)  MoEScalingType.MXFP8           11590.7               8782.91  1.32x                           3299.3           2417.25  1.365x
(128000, 2048, 7168, 2)  MoEScalingType.MXFP8           11312.1               8810.46  1.284x                          2870.37          2473.54  1.16x
(128000, 2048, 7168, 4)  MoEScalingType.MXFP8           11178.1               8475.62  1.319x                          3399.66          2412.51  1.409x
(128000, 2048, 7168, 8)  MoEScalingType.MXFP8           11716.6               8387.01  1.397x                          3316.77          2420.77  1.37x
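The speedup columns are consistent with the bf16 time divided by the scaled time. For example, taking the first row of the "Before" table:

```python
# Speedup appears to be bf16 time / scaled time; values are from the first
# "Before" row above (fwd+bwd latencies in microseconds).
bf16_fwd_bwd_us = 32520.3
scaled_fwd_bwd_us = 19518.3
speedup = bf16_fwd_bwd_us / scaled_fwd_bwd_us
print(round(speedup, 3))  # 1.666
```

This matches the reported 1.666x for that configuration.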

stack-info: PR: #3545, branch: danielvegamyhre/stack/91
…h groups along M

stack-info: PR: #3546, branch: danielvegamyhre/stack/92
pytorch-bot bot commented Dec 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3556

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit f1bf15c with merge base 25418fd (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 29, 2025
@danielvegamyhre danielvegamyhre added mx topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) moe labels Dec 29, 2025
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/92 branch 3 times, most recently from f32f3e1 to 7a12730 Compare December 31, 2025 01:23