Skip to content

[ci] fix ci test fused_moe op #5102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class TestFile:
TestFile("test_create_kvindices.py", 2),
TestFile("test_hicache.py", 60),
TestFile("test_hicache_mla.py", 90),
TestFile("test_fused_moe.py", 30),
],
"per-commit-2-gpu": [
TestFile("models/lora/test_lora_tp.py", 300),
Expand Down
64 changes: 42 additions & 22 deletions test/srt/test_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
Expand Down Expand Up @@ -45,20 +44,49 @@ def get_tolerance(self, dtype):
else:
return 1e-2, 1e-2 # Default values for other types

def torch_naive_moe(self, a, w1, w2, score, topk):
def torch_naive_moe(
self,
a,
w1,
w2,
score,
topk,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):

if w1.dtype == torch.float8_e4m3fn:
w1_orig = w1.to(a.dtype)
w2_orig = w2.to(a.dtype)

if w1_scale is not None:
w1_orig = (w1_orig * w1_scale.view(-1, 1, 1)).to(a.dtype)
if w2_scale is not None:
w2_orig = (w2_orig * w2_scale.view(-1, 1, 1)).to(a.dtype)
if a1_scale is not None:
a = (a * a1_scale).to(a.dtype)
if a2_scale is not None:
a = (a * a2_scale).to(a.dtype)
else:
w1_orig = w1
w2_orig = w2

for i in range(w1_orig.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[
i
].transpose(0, 1)
out[mask] = SiluAndMul()(
a[mask] @ w1_orig[i].transpose(0, 1)
) @ w2_orig[i].transpose(0, 1)

return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
Expand Down Expand Up @@ -98,21 +126,12 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
a2_scale=a2_scale,
)

vllm_output = fused_moe_vllm(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
torch_output = self.torch_naive_moe(
a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale
)
torch.testing.assert_close(
sglang_output, torch_output, rtol=rtol, atol=atol
)

torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol)

else:
a = self.create_random_cuda_tensor((m, k), dtype)
Expand All @@ -127,8 +146,8 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
)

def test_various_configurations(self):
m_values = [1, 33, 64, 222, 1024 * 128]
n_values = [128, 1024, 2048]
m_values = [1, 33, 64, 222]
n_values = [128, 1024]
k_values = [128, 511, 1024]
dtypes = [torch.float16, torch.bfloat16]
fp8_modes = [False, True]
Expand Down Expand Up @@ -171,6 +190,7 @@ def test_various_configurations(self):
dtype,
use_fp8_w8a8=use_fp8_w8a8,
)
torch.cuda.empty_cache()
pbar.update(1)


Expand Down
Loading