python/sglang/srt/layers/quantization: 1 file changed, +13 −3 lines

@@ -243,9 +243,19 @@ def apply_fp8_linear(
     if _is_cuda:
         qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
     else:
-        qinput, x_scale = ops.scaled_fp8_quant(
-            input_2d, input_scale, use_per_token_if_dynamic=use_per_token_if_dynamic
-        )
+        # TODO(kkhuang): temporarily enforce per-tensor activation scaling when the weight uses per-tensor scaling.
+        # The final solution should: 1. add support for per-tensor activation scaling, and
+        # 2. fix the torch.compile error from weight_scale.numel() == 1 with x_scale.numel() > 1 (line #308 below).
+        if _is_hip and weight_scale.numel() == 1:
+            qinput, x_scale = ops.scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                use_per_token_if_dynamic=use_per_token_if_dynamic,
+            )
+        else:
+            qinput, x_scale = per_token_group_quant_fp8(
+                input_2d, group_size=input_2d.shape[1]
+            )

     if cutlass_fp8_supported:
         try:
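For context on the new fallback: calling per_token_group_quant_fp8 with group_size equal to the full row width (input_2d.shape[1]) amounts to dynamic per-token quantization, i.e. one FP8 scale per token. Below is a minimal sketch of that scheme in plain PyTorch. It is an illustration, not sglang's actual kernel: per_token_quant_fp8_sketch and FP8_MAX are names introduced here, and it assumes a PyTorch build with float8_e4m3fn support.

import torch

# Largest finite value representable in float8_e4m3fn (448.0).
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def per_token_quant_fp8_sketch(x: torch.Tensor):
    """Quantize each row of a 2-D tensor to FP8 with its own scale.

    Hypothetical sketch, not sglang's kernel. Returns (q, scale) with q in
    float8_e4m3fn and scale of shape (num_tokens, 1), so x ~= q.float() * scale.
    """
    # One scale per token: map the row-wise absolute maximum onto the FP8 range.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX
    q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale

x = torch.randn(4, 128)
q, s = per_token_quant_fp8_sketch(x)
print(q.shape, s.shape)  # torch.Size([4, 128]) torch.Size([4, 1])

Per-token scales adapt to activation outliers row by row, which is why this is the preferred path; the HIP branch above falls back to per-tensor activation scaling only to sidestep the torch.compile shape mismatch noted in the TODO.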