Skip to content

Fix run time error in ROCm platform #5147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions python/sglang/srt/layers/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import triton
import triton.language as tl

from sglang.srt.utils import is_hip

_is_hip = is_hip()

fused_softcap_autotune = triton.autotune(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
Expand Down Expand Up @@ -185,6 +189,9 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
assert x.shape == residual.shape and x.dtype == residual.dtype
output, mid = torch.empty_like(x), torch.empty_like(x)
bs, hidden_dim = x.shape

min_num_warps = 16 if _is_hip else 32

if autotune:
fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
Expand All @@ -193,7 +200,10 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
min(
triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
),
4,
),
}

Expand Down Expand Up @@ -250,10 +260,13 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
else:
output = torch.empty_like(x)
bs, hidden_dim = x.shape

min_num_warps = 16 if _is_hip else 32

config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
),
}

Expand Down
8 changes: 7 additions & 1 deletion python/sglang/srt/layers/moe/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import triton.language as tl

from sglang.srt.layers.moe.topk import fused_topk
from sglang.srt.utils import is_hip

_is_hip = is_hip()


@triton.jit
Expand Down Expand Up @@ -116,10 +119,13 @@ def fused_moe_router_impl(
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)

grid = lambda meta: (bs,)

min_num_warps = 16 if _is_hip else 32

config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
),
}

Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def input_to_float8(
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
scale = fp8_max / amax
x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
Expand Down
Loading