
Commit e7328db

Fix run time error in ROCm platform (sgl-project#5147)

Authored by kkHuang-amd
Co-authored-by: wunhuang <[email protected]>
Co-authored-by: root <[email protected]>

1 parent fabe6d4 · commit e7328db

File tree: 3 files changed, +23 −3 lines

python/sglang/srt/layers/elementwise.py

Lines changed: 15 additions & 2 deletions
@@ -4,6 +4,10 @@
 import triton
 import triton.language as tl
 
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -185,6 +189,9 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
@@ -193,7 +200,10 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+                min(
+                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
+                ),
+                4,
             ),
         }
 
@@ -250,10 +260,13 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
 
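For context on why the cap changes: Triton's num_warps meta-parameter is multiplied by the hardware warp width to give the threads per workgroup. On NVIDIA a warp is 32 threads, so 32 warps is 1024 threads; on AMD a wavefront is 64 threads, so 32 warps would request 2048 threads and exceed the usual 1024-thread workgroup limit, which is the likely source of the runtime error this commit fixes. Capping at 16 on HIP keeps the launch at 16 × 64 = 1024 threads. A minimal sketch of the selection logic shared by all three call sites (the hardware-limit rationale is my reading of the change, not stated in the commit):

import triton

def pick_num_warps(hidden_dim: int, is_hip: bool) -> int:
    # Despite its name, min_num_warps acts as the upper bound here;
    # 4 is the lower bound. On ROCm, 16 warps x 64-thread wavefronts
    # equals the same 1024 threads as 32 warps x 32 threads on CUDA.
    min_num_warps = 16 if is_hip else 32
    return max(
        min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
    )

# e.g. hidden_dim=8192: cdiv(8192, 256)=32, so 16 on HIP but 32 on CUDA;
#      hidden_dim=1024: cdiv=4, so num_warps=4 on both platforms.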

python/sglang/srt/layers/moe/router.py

Lines changed: 7 additions & 1 deletion
@@ -5,6 +5,9 @@
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import fused_topk
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
 
     grid = lambda meta: (bs,)
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
 
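In the router the cap feeds the same config dict, which is then splatted into the Triton launch. A toy, runnable sketch of that launch pattern — the kernel below is a hypothetical stand-in, not the real fused_moe_router kernel:

import torch
import triton
import triton.language as tl

@triton.jit
def _row_copy_kernel(src, dst, hidden_dim, BLOCK_SIZE: tl.constexpr):
    # One program per row; BLOCK_SIZE lanes cover the hidden dimension.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < hidden_dim
    vals = tl.load(src + row * hidden_dim + offs, mask=mask)
    tl.store(dst + row * hidden_dim + offs, vals, mask=mask)

bs, hidden_dim = 4, 4096
x = torch.randn(bs, hidden_dim, device="cuda")  # ROCm PyTorch also uses "cuda"
out = torch.empty_like(x)
config = {
    "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
    "num_warps": 16,  # within the cap on both HIP and CUDA
}
_row_copy_kernel[(bs,)](x, out, hidden_dim, **config)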

python/sglang/srt/layers/quantization/fp8_utils.py

Lines changed: 1 addition & 0 deletions
@@ -171,6 +171,7 @@ def input_to_float8(
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
     fp8_max = finfo.max
     if _is_hip:
+        dtype = torch.float8_e4m3fnuz
         fp8_max = 224.0
     scale = fp8_max / amax
     x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
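For context: the HIP branch already clamped to fp8_max = 224.0, but the output dtype still defaulted to CUDA's torch.float8_e4m3fn. ROCm hardware uses the fnuz variant, whose finite range is roughly half as large (240 vs 448; the extra margin down to 224 is not explained in the commit), so casting to the wrong fp8 dtype is the plausible source of the runtime error. A quick check of the two ranges:

import torch

# e4m3fn is the CUDA-side default; e4m3fnuz is ROCm's native fp8 format.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0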
