
[perf] introduce deep gemm group_gemm_masked as bmm #5432


Merged · 4 commits · Apr 20, 2025

112 changes: 108 additions & 4 deletions python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -44,6 +44,7 @@
fp8_min = -fp8_max

_enable_jit_deepgemm = False
_enable_jit_deepgemm_bmm = False
if _is_cuda:
import deep_gemm
from sgl_kernel import (
@@ -53,10 +54,11 @@
)

sm_version = get_device_sm()
if sm_version == 90 and get_bool_env_var(
"SGL_ENABLE_JIT_DEEPGEMM", default="false"
):
_enable_jit_deepgemm = True
if sm_version == 90:
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
_enable_jit_deepgemm = True
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
_enable_jit_deepgemm_bmm = True


logger = logging.getLogger(__name__)
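
Note: the new flag mirrors the existing `SGL_ENABLE_JIT_DEEPGEMM` switch — both are read once at import time and only take effect on SM90 (Hopper) CUDA devices. A minimal sketch of turning the new BMM path on (hypothetical launch snippet; in practice the variables would be set in the serving environment before SGLang is imported):

```python
import os

# Both flags are evaluated when fp8_kernel is first imported,
# so they must be set beforehand. Illustrative only.
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "true"
os.environ["SGL_ENABLE_JIT_DEEPGEMM_BMM"] = "true"

from sglang.srt.layers.quantization.fp8_kernel import _enable_jit_deepgemm_bmm

# True only on SM90 CUDA devices with deep_gemm available.
print(_enable_jit_deepgemm_bmm)
```
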
@@ -919,6 +921,108 @@ def per_tensor_quant_mla_fp8(
return x_q, x_s


@triton.jit
def _per_token_group_quant_mla_deep_gemm_masked_fp8(
y_ptr,
y_q_ptr,
y_s_ptr,
masked_m_ptr,
group_size,
y_stride_b,
y_stride_t,
y_q_stride_b,
y_q_stride_t,
y_s_stride_b,
y_s_stride_g,
eps,
fp8_min,
fp8_max,
NUM_GROUP: tl.constexpr,
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor for deep_gemm grouped_gemm_masked.
This function converts the tensor values into float8 values.
y and y_q: (b, t, k)
y_s: (b, k//group_size, t)
"""
t_id = tl.program_id(0)
b_id = tl.program_id(1)

y_ptr += b_id * y_stride_b + t_id * y_stride_t
y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
y_s_ptr += b_id * y_s_stride_b + t_id

if t_id == 0:
tl.store(masked_m_ptr + b_id, tl.num_programs(0))

cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size

for gid in range(NUM_GROUP):
y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
tl.float32
)
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
tl.store(y_s_ptr + gid * y_s_stride_g, y_s)


def per_tensor_quant_mla_deep_gemm_masked_fp8(
x: torch.Tensor,
group_size: int = 128,
eps: float = 1e-12,
dtype: torch.dtype = torch.float8_e4m3fn,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This function quantizes input values to float8 with per-token-group quantization
for deep_gemm grouped_gemm_masked and is specialized for the MLA absorbed case.
"""
assert x.dim() == 3, "`x` is not a 3d-tensor"

finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0

b, m, k = x.shape
aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel
num_tiles_k = k // group_size
assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"

x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
masked_m = x.new_empty((b,), dtype=torch.int32)

BLOCK_SIZE = triton.next_power_of_2(group_size)
grid = (m, b)

_per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
x,
x_q,
x_s,
masked_m,
group_size,
x.stride(0),
x.stride(1),
x_q.stride(0),
x_q.stride(1),
x_s.stride(0),
x_s.stride(1),
eps,
-fp8_max,
fp8_max,
num_tiles_k,
BLOCK_SIZE,
)

return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m


def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
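For clarity, a minimal usage sketch of the new helper; the shapes follow directly from the function body, and the example assumes a CUDA device with Triton available (16 heads as the group dimension, 7 tokens, head dim 512, default `group_size=128`):

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    per_tensor_quant_mla_deep_gemm_masked_fp8,
)

# Illustrative shapes only: (num_heads, num_tokens, kv_lora_rank).
x = torch.randn(16, 7, 512, device="cuda", dtype=torch.bfloat16)

x_q, x_s, masked_m, expected_m, aligned_m = per_tensor_quant_mla_deep_gemm_masked_fp8(x)

# x_q:      (16, 256, 512) float8_e4m3fn, padded up to the 256-row tile
# x_s:      (16, 256, 4)   float32 per-token-group scales (512 // 128 = 4 groups)
# masked_m: (16,) int32, every entry == 7, so the masked grouped GEMM only
#           computes the first 7 rows of each head
# expected_m == 7, aligned_m == 256; callers slice the GEMM output back to 7 rows
```
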
102 changes: 86 additions & 16 deletions python/sglang/srt/models/deepseek_v2.py
@@ -57,7 +57,11 @@
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
_enable_jit_deepgemm_bmm,
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,
channel_quant_to_tensor_quant,
@@ -82,6 +86,7 @@
_is_cuda = is_cuda()

if _is_cuda:
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
else:
from vllm._custom_ops import awq_dequantize
@@ -533,6 +538,10 @@ def __init__(
self.w_vc = None
self.w_scale = None

self.w_scale_k = None
self.w_scale_v = None
self.use_deep_gemm_bmm = False

self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
@@ -681,7 +690,24 @@ def forward_absorb(
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

if self.w_kc.dtype == torch.float8_e4m3fnuz:
if self.use_deep_gemm_bmm:
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
per_tensor_quant_mla_deep_gemm_masked_fp8(
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
)
)
q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank)
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
(q_nope_val, q_nope_scale),
(self.w_kc, self.w_scale_k),
q_nope_out,
masked_m,
expected_m,
)
q_nope_out = q_nope_out[:, :expected_m, :]
elif self.w_kc.dtype == torch.float8_e4m3fnuz:
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
@@ -712,7 +738,24 @@ def forward_absorb(
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

if self.w_vc.dtype == torch.float8_e4m3fnuz:
if self.use_deep_gemm_bmm:
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
per_tensor_quant_mla_deep_gemm_masked_fp8(
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
)
)
attn_bmm_output = attn_output.new_empty(
(self.num_local_heads, aligned_m, self.v_head_dim)
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
(attn_output_val, attn_output_scale),
(self.w_vc, self.w_scale_v),
attn_bmm_output,
masked_m,
expected_m,
)
attn_bmm_output = attn_bmm_output[:, :expected_m, :]
elif self.w_vc.dtype == torch.float8_e4m3fnuz:
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
attn_bmm_output = torch.bmm(
attn_output.to(torch.bfloat16).transpose(0, 1),
@@ -1405,6 +1448,10 @@ def post_load_weights(self):
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
# This may affect the accuracy of the fp8 model.
# Fix DeepSeek-V3 blockwise bmm by using deep_gemm instead.
use_deep_gemm_bmm = False
model_dtype = torch.get_default_dtype()

if w.dtype in (
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
@@ -1423,10 +1470,20 @@
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv

w, scale = block_quant_to_tensor_quant(
weight, weight_scale, weight_block_size
)
self_attn.w_scale = scale
if (
_is_cuda
and _enable_jit_deepgemm_bmm
and weight_block_size[0] == 128
and weight_block_size[1] == 128
and model_dtype == torch.bfloat16
):
block_scale = weight_scale
use_deep_gemm_bmm = True
else:
w, scale = block_quant_to_tensor_quant(
weight, weight_scale, weight_block_size
)
self_attn.w_scale = scale
else:
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale
@@ -1449,18 +1506,31 @@
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
torch.bfloat16
)

w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if (
hasattr(self_attn.kv_b_proj, "weight_scale")
and self_attn.w_scale is None
):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
if _is_hip:
self_attn.w_scale *= 2.0
if not use_deep_gemm_bmm:
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if (
hasattr(self_attn.kv_b_proj, "weight_scale")
and self_attn.w_scale is None
):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
if _is_hip:
self_attn.w_scale *= 2.0
else:
num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
ws_kc, ws_vc = block_scale.unflatten(
0, (-1, (num_tiles_k + num_tiles_n))
).split([num_tiles_k, num_tiles_n], dim=1)
self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
self_attn.w_scale_v = ws_vc.contiguous()
self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
self_attn.w_vc = w_vc.contiguous()
self_attn.use_deep_gemm_bmm = True

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
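As a reference for the new branch above, a sketch of the shapes it produces, assuming DeepSeek-V3-style MLA dimensions and 128×128 FP8 weight blocks (qk_nope_head_dim = v_head_dim = 128, kv_lora_rank = 512; the exact sizes depend on the checkpoint):

```python
import torch

H, qk_nope_head_dim, v_head_dim, kv_lora_rank, blk = 128, 128, 128, 512, 128

# kv_b_proj weight and its blockwise inverse scales, as loaded from the
# checkpoint (kept on GPU in the real model).
w = torch.empty(
    H * (qk_nope_head_dim + v_head_dim), kv_lora_rank, dtype=torch.float8_e4m3fn
)
block_scale = torch.empty(w.shape[0] // blk, w.shape[1] // blk)  # (256, 4) fp32

num_tiles_k = qk_nope_head_dim // blk  # 1
num_tiles_n = v_head_dim // blk        # 1

w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
ws_kc, ws_vc = block_scale.unflatten(0, (-1, num_tiles_k + num_tiles_n)).split(
    [num_tiles_k, num_tiles_n], dim=1
)

w_kc = w_kc.transpose(1, 2).contiguous()        # (H, 512, 128)
w_scale_k = ws_kc.transpose(1, 2).contiguous()  # (H, 4, 1)
w_vc = w_vc.contiguous()                        # (H, 128, 512)
w_scale_v = ws_vc.contiguous()                  # (H, 1, 4)
```

Keeping the blockwise scales here is what lets `forward_absorb` feed `(w_kc, w_scale_k)` and `(w_vc, w_scale_v)` straight into `m_grouped_gemm_fp8_fp8_bf16_nt_masked`, avoiding the per-tensor requantization (and its accuracy cost) that the `bmm_fp8` path requires.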