[Quant Kernel] refactored per token group quant fp8 to support int8 up-to 2x faster #4396

Merged: 16 commits into sgl-project:main on Mar 24, 2025

Conversation

@zcnrex (Contributor) commented on Mar 13, 2025

Motivation

#2965

The new CUDA kernel is roughly 1x-2x faster than the Triton kernel. The existing fp8 kernel was refactored so that a single implementation supports both fp8 and int8 output quantization types.
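
For intuition, here is a minimal torch sketch of what per-token-group quantization computes for either destination dtype. It is illustrative only, not the kernel's code; the helper name, the eps clamp, and the symmetric clamp range are assumptions.

```python
import torch

def ref_per_token_group_quant(x: torch.Tensor, group_size: int, dst_dtype: torch.dtype):
    """Reference per-token-group quantization (illustrative, not the CUDA kernel)."""
    # dynamic range of the destination dtype (448 for float8_e4m3fn, 127 for int8)
    dtype_max = (
        float(torch.finfo(dst_dtype).max)
        if dst_dtype.is_floating_point
        else float(torch.iinfo(dst_dtype).max)
    )
    xg = x.reshape(-1, group_size).float()
    # one scale per contiguous group of `group_size` elements
    absmax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scale = absmax / dtype_max
    q = (xg / scale).clamp(-dtype_max, dtype_max)
    if not dst_dtype.is_floating_point:
        q = q.round()  # integer targets need an explicit rounding step
    return q.to(dst_dtype).reshape(x.shape), scale.reshape(*x.shape[:-1], -1)
```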

✅ torch.int8 implementations match
✅ torch.float8_e4m3fn implementations match
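
A check along these lines could confirm the match; `triton_per_token_group_quant_8bit` and `sgl_per_token_group_quant_8bit` are hypothetical wrapper names standing in for the two kernel entry points, and the shapes are arbitrary:

```python
import torch

def check_match(dst_dtype, group_size=128, hidden=7168):
    # hidden size and batch are arbitrary; both wrappers return (quantized, scales)
    x = torch.randn(16, hidden, device="cuda", dtype=torch.float16)
    q_ref, s_ref = triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
    q_new, s_new = sgl_per_token_group_quant_8bit(x, group_size, dst_dtype)
    torch.testing.assert_close(s_new, s_ref)
    # compare quantized values as float so the occasional off-by-one (see below) is tolerated
    torch.testing.assert_close(q_new.float(), q_ref.float(), atol=1.0, rtol=0)

for dtype in (torch.int8, torch.float8_e4m3fn):
    check_match(dtype)
```
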
per-token-group-quant-8bit-performance:

```
    batch_size  seq_len  group_size            dst_dtype       Triton   SGL Kernel
0            1       64         128           torch.int8    10.848000     9.152000
1            1       64         128  torch.float8_e4m3fn    10.656000     9.344000
2            1      128         128           torch.int8    14.816000    11.424000
3            1      128         128  torch.float8_e4m3fn    14.656000    11.968000
4            1      256         128           torch.int8    20.223999    13.536000
5            1      256         128  torch.float8_e4m3fn    20.160001    14.592000
6            1      512         128           torch.int8    31.840000    18.528000
7            1      512         128  torch.float8_e4m3fn    31.808000    20.544000
8            1     1024         128           torch.int8    55.103999    28.031999
9            1     1024         128  torch.float8_e4m3fn    55.071998    32.512002
10           1     2048         128           torch.int8   101.920001    48.640002
11           1     2048         128  torch.float8_e4m3fn   101.824000    59.168000
12           2       64         128           torch.int8    14.656000    11.392000
13           2       64         128  torch.float8_e4m3fn    14.624000    11.936000
14           2      128         128           torch.int8    20.223999    13.408000
15           2      128         128  torch.float8_e4m3fn    20.191999    14.528000
16           2      256         128           torch.int8    31.776000    18.816000
17           2      256         128  torch.float8_e4m3fn    31.744000    20.479999
18           2      512         128           torch.int8    55.135999    28.031999
19           2      512         128  torch.float8_e4m3fn    55.119999    32.575998
20           2     1024         128           torch.int8   101.824000    48.544001
21           2     1024         128  torch.float8_e4m3fn   101.824000    59.168000
22           2     2048         128           torch.int8   193.471998    86.719997
23           2     2048         128  torch.float8_e4m3fn   193.440005   106.720001
24           4       64         128           torch.int8    20.223999    13.408000
25           4       64         128  torch.float8_e4m3fn    20.223999    14.560000
26           4      128         128           torch.int8    31.711999    18.432001
27           4      128         128  torch.float8_e4m3fn    31.744000    20.447999
28           4      256         128           torch.int8    55.167999    28.031999
29           4      256         128  torch.float8_e4m3fn    55.103999    32.575998
30           4      512         128           torch.int8   101.792000    48.576001
31           4      512         128  torch.float8_e4m3fn   101.792000    59.119999
32           4     1024         128           torch.int8   193.376005    86.719997
33           4     1024         128  torch.float8_e4m3fn   193.471998   106.784001
34           4     2048         128           torch.int8   375.871986   160.735995
35           4     2048         128  torch.float8_e4m3fn   375.936002   200.992003
36           8       64         128           torch.int8    31.744000    18.400000
37           8       64         128  torch.float8_e4m3fn    31.711999    20.447999
38           8      128         128           torch.int8    55.103999    27.968001
39           8      128         128  torch.float8_e4m3fn    55.071998    32.543998
40           8      256         128           torch.int8   101.792000    48.640002
41           8      256         128  torch.float8_e4m3fn   101.952001    59.103999
42           8      512         128           torch.int8   193.535998    86.783998
43           8      512         128  torch.float8_e4m3fn   193.504006   106.720001
44           8     1024         128           torch.int8   375.968009   160.607994
45           8     1024         128  torch.float8_e4m3fn   375.903994   201.087996
46           8     2048         128           torch.int8   740.447998   308.544010
47           8     2048         128  torch.float8_e4m3fn   740.736008   389.472008
48          16       64         128           torch.int8    55.167999    28.031999
49          16       64         128  torch.float8_e4m3fn    55.135999    32.575998
50          16      128         128           torch.int8   101.952001    48.672002
51          16      128         128  torch.float8_e4m3fn   101.952001    59.200000
52          16      256         128           torch.int8   193.440005    86.687997
53          16      256         128  torch.float8_e4m3fn   193.471998   106.688000
54          16      512         128           torch.int8   375.903994   160.768002
55          16      512         128  torch.float8_e4m3fn   375.871986   200.992003
56          16     1024         128           torch.int8   740.800023   308.447987
57          16     1024         128  torch.float8_e4m3fn   740.544021   389.616013
58          16     2048         128           torch.int8  1470.399976   604.416013
59          16     2048         128  torch.float8_e4m3fn  1470.655918   767.328024
60          32       64         128           torch.int8   101.984002    48.672002
61          32       64         128  torch.float8_e4m3fn   101.952001    59.200000
62          32      128         128           torch.int8   193.120003    86.719997
63          32      128         128  torch.float8_e4m3fn   193.471998   106.656000
64          32      256         128           torch.int8   375.903994   160.735995
65          32      256         128  torch.float8_e4m3fn   375.903994   201.023996
66          32      512         128           torch.int8   740.736008   308.512002
67          32      512         128  torch.float8_e4m3fn   740.831971   389.712006
68          32     1024         128           torch.int8  1470.960021   604.543984
69          32     1024         128  torch.float8_e4m3fn  1470.128059   766.943991
70          32     2048         128           torch.int8  2928.992033  1194.767952
71          32     2048         128  torch.float8_e4m3fn  2929.183960  1521.376014
72          64       64         128           torch.int8   193.535998    86.687997
73          64       64         128  torch.float8_e4m3fn   193.455994   106.656000
74          64      128         128           torch.int8   375.936002   160.704002
75          64      128         128  torch.float8_e4m3fn   375.936002   201.040000
76          64      256         128           torch.int8   740.863979   308.447987
77          64      256         128  torch.float8_e4m3fn   740.895987   389.663994
78          64      512         128           torch.int8  1470.095992   604.351997
79          64      512         128  torch.float8_e4m3fn  1470.672011   767.072022
80          64     1024         128           torch.int8  2928.800106  1194.975972
81          64     1024         128  torch.float8_e4m3fn  2928.832054  1521.183968
82          64     2048         128           torch.int8  5845.711708  2376.575947
83          64     2048         128  torch.float8_e4m3fn  5845.935822  3028.848171
```
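
For reference, numbers like these could be collected with `triton.testing.do_bench`; the wrapper names follow the hypothetical ones above, and the hidden size and the microsecond conversion are assumptions, not details taken from this PR:

```python
import torch
import triton

def bench_one(batch_size, seq_len, group_size, dst_dtype, fn, hidden=7168):
    # `fn` is one of the hypothetical wrappers above: fn(x, group_size, dst_dtype)
    x = torch.randn(batch_size * seq_len, hidden, device="cuda", dtype=torch.float16)
    ms = triton.testing.do_bench(lambda: fn(x, group_size, dst_dtype))
    return ms * 1000.0  # do_bench reports milliseconds; scale to microseconds
```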

int8 precision issue: some quantized values differ by 1 between the Triton and CUDA kernels.

Comparing each combination of the Triton and CUDA conversions listed below, with the "-use_fast_math" nvcc flag disabled:

| triton # | cuda # | % diff |
| --- | --- | --- |
| 1 | 1 | 0.0001373291015625 |
| 2 | 1 | 0.4912109375 |
| 1 | 2 | 0.492462158203125 |
| 2 | 2 | 0.0035552978515625 |
| 1 | 3 | 0.0001068115234375 |
| 2 | 3 | 0.4933319091796875 |

Comparing the same combinations with the "-use_fast_math" nvcc flag enabled:

| triton # | cuda # | % diff |
| --- | --- | --- |
| 1 | 1 | 0.00018310546875 |
| 2 | 1 | 0.49603271484375 |
| 2 | 2 | ✅ 0 |

Two Triton int8 conversions were tried:

```python
# 1: scale, explicit round to nearest, then cast
y_q = y * int8_max / _absmax
y_q = tl.extra.cuda.libdevice.round(y_q).to(tl.int8)

# 2: divide by the scale, clamp, then plain cast
y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)
```

Three CUDA int8 conversions were tried:

```cpp
// 1: PTX saturating convert with round-to-nearest-even
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(q_val));
group_output[i * vec_size + j] = reinterpret_cast<const int8_t&>(dst);

// 2: plain C++ cast (truncates toward zero)
group_output[i * vec_size + j] = int8_t(q_val);

// 3: ROCm-style implementation (round to nearest, then clamp)
float dst = std::nearbyint(q_val);
dst = std::clamp(dst, int8_min, int8_max);
group_output[i * vec_size + j] = static_cast<int8_t>(dst);
```
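
The off-by-one cells look consistent with rounding-mode differences rather than a correctness bug: Triton # 1, CUDA # 1 (`cvt.rni`), and CUDA # 3 (`nearbyint`) round to nearest, while Triton # 2's plain cast and CUDA # 2's `int8_t(q_val)` truncate toward zero, which matches the pairs that agree in the tables above. A tiny host-side illustration (plain Python, not kernel code):

```python
q_val = 38.7           # example value after dividing by the group scale
print(int(q_val))      # 38: truncation toward zero (Triton # 2, CUDA # 2)
print(round(q_val))    # 39: round to nearest (Triton # 1, CUDA # 1, CUDA # 3)
```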


@zcnrex force-pushed the token-group-int8 branch from 68c70e8 to 6f92123 on March 13, 2025 20:10
@zcnrex changed the title from "[Quant Kernel] Per token group quant int8" to "[Quant Kernel] Per token group quant int8 up-to 2x faster" on Mar 13, 2025
@zcnrex requested review from Ying1123 and HaiShaw as code owners on March 14, 2025 06:02
@zcnrex changed the title from "[Quant Kernel] Per token group quant int8 up-to 2x faster" to "[Quant Kernel] refactored per token group quant fp8 to support int8 up-to 2x faster" on Mar 14, 2025
@zcnrex force-pushed the token-group-int8 branch from a2cdb9f to 9810175 on March 14, 2025 15:58
@zcnrex force-pushed the token-group-int8 branch from 9810175 to acf67a1 on March 14, 2025 16:54
@zcnrex force-pushed the token-group-int8 branch from 08af2c0 to 1614cc4 on March 17, 2025 16:02
@zcnrex force-pushed the token-group-int8 branch from 1614cc4 to 33b9f93 on March 19, 2025 16:06
@zcnrex force-pushed the token-group-int8 branch 2 times, most recently from 6d0ba94 to d0cec50 on March 23, 2025 06:11
@zcnrex force-pushed the token-group-int8 branch from d0cec50 to 7f32d00 on March 23, 2025 06:42
@zhyncs merged commit 65c24c2 into sgl-project:main on Mar 24, 2025 (2 of 5 checks passed)