
Commit 8d1a148

kkHuang-amd and wunhuang authored and committed
Fix ci test "test_eval_fp8_accuracy" failed (sgl-project#5185)
Co-authored-by: wunhuang <[email protected]>
1 parent cfe4aa5 commit 8d1a148

File tree

1 file changed (+13 −3)


python/sglang/srt/layers/quantization/fp8_utils.py

Lines changed: 13 additions & 3 deletions
@@ -243,9 +243,19 @@ def apply_fp8_linear(
         if _is_cuda:
             qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
         else:
-            qinput, x_scale = ops.scaled_fp8_quant(
-                input_2d, input_scale, use_per_token_if_dynamic=use_per_token_if_dynamic
-            )
+            # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+            # final solution should be: 1. add support to per-tensor activation scaling.
+            # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+            if _is_hip and weight_scale.numel() == 1:
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+            else:
+                qinput, x_scale = per_token_group_quant_fp8(
+                    input_2d, group_size=input_2d.shape[1]
+                )
 
         if cutlass_fp8_supported:
             try:
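
For context, passing group_size=input_2d.shape[1] to per_token_group_quant_fp8 makes each quantization group span an entire row, so the new fallback amounts to per-token (per-row) dynamic FP8 quantization: one scale per token. Below is a minimal PyTorch sketch of that semantics, assuming the e4m3 format; the helper name per_token_quant_fp8_reference is our own illustration, not SGLang's fused kernel.

    import torch

    # Largest finite value representable in FP8 e4m3 (448.0).
    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

    def per_token_quant_fp8_reference(x: torch.Tensor):
        # Hypothetical reference for per-token dynamic FP8 quantization:
        # each row of the 2-D input gets one scale, chosen so the row's
        # max-abs value maps to the e4m3 maximum. Illustrative only; the
        # real kernels fuse this work on-device.
        amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
        scale = amax / FP8_E4M3_MAX  # one scale per row (per token)
        q = (x / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
        return q, scale

    # Usage: q, s = per_token_quant_fp8_reference(input_2d)
    # s has shape [num_tokens, 1]; dequantize with q.float() * s.

The HIP branch keeps ops.scaled_fp8_quant instead because, per the in-code TODO, a per-tensor weight scale (weight_scale.numel() == 1) combined with a per-token activation scale currently trips a torch.compile error further down in the file.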

0 commit comments