Skip to content

Commit e6af9a4

Browse files
Knarf04 and valarLip authored
[MXFP4] Patch fp4_utils.py rounding logic following #975 (#2249)
* [MXFP4] Patch fp4_utils.py rounding logic following #975
* Fix formatting
* [MXFP4] Trim the padding for non-shuffled input

Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>
1 parent 8bd4656 commit e6af9a4

File tree

2 files changed

+108
-16
lines changed

2 files changed

+108
-16
lines changed

aiter/utility/fp4_utils.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -296,30 +296,63 @@ def _dynamic_mxfp4_quant_kernel_asm_layout(
296296
# S101 -> +/- 3.0
297297
# S110 -> +/- 4.0
298298
# S111 -> +/- 6.0
299+
# FP4 format constants
300+
EXP_BIAS_FP32: tl.constexpr = 127
301+
EXP_BIAS_FP4: tl.constexpr = 1
302+
EBITS_F32: tl.constexpr = 8
303+
EBITS_FP4: tl.constexpr = 2
304+
MBITS_F32: tl.constexpr = 23
305+
MBITS_FP4: tl.constexpr = 1
306+
307+
max_normal: tl.constexpr = 6
308+
min_normal: tl.constexpr = 1
309+
299310
qx = qx.to(tl.uint32, bitcast=True)
300311

301-
# Extract sign, exponents and mantissa fields from FP32
312+
# Extract sign
302313
s = qx & 0x80000000
303-
e = (qx >> 23) & 0xFF
304-
m = qx & 0x7FFFFF
314+
# Set everything to positive, will add sign back at the end
315+
qx = qx ^ s
305316

306-
E8_BIAS: tl.constexpr = 127
307-
E2_BIAS: tl.constexpr = 1
317+
qx_fp32 = qx.to(tl.float32, bitcast=True)
318+
saturate_mask = qx_fp32 >= max_normal
319+
denormal_mask = (not saturate_mask) & (qx_fp32 < min_normal)
320+
normal_mask = not (saturate_mask | denormal_mask)
308321

309322
# Denormal numbers
310-
# If exponent is less than 127, then it's a denormal number
311-
# See above, for denormal number mantissa is always 1 and we set bit 1 of mantissa
312-
adjusted_exponents = tl.core.sub(E8_BIAS, e + 1, sanitize_overflow=False)
313-
m = tl.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m)
323+
denorm_exp: tl.constexpr = (
324+
(EXP_BIAS_FP32 - EXP_BIAS_FP4) + (MBITS_F32 - MBITS_FP4) + 1
325+
)
326+
denorm_mask_int: tl.constexpr = denorm_exp << MBITS_F32
327+
denorm_mask_float: tl.constexpr = tl.cast(denorm_mask_int, tl.float32, bitcast=True)
328+
329+
denormal_x = qx_fp32 + denorm_mask_float
330+
denormal_x = denormal_x.to(tl.uint32, bitcast=True)
331+
denormal_x -= denorm_mask_int
332+
denormal_x = denormal_x.to(tl.uint8)
333+
334+
# Normal numbers
335+
normal_x = qx
336+
# resulting mantissa is odd
337+
mant_odd = (normal_x >> (MBITS_F32 - MBITS_FP4)) & 1
338+
# update exponent, rounding bias part 1
339+
val_to_add = ((EXP_BIAS_FP4 - EXP_BIAS_FP32) << MBITS_F32) + (1 << 21) - 1
340+
normal_x += val_to_add
341+
# rounding bias part 2
342+
normal_x += mant_odd
343+
# take the bits!
344+
normal_x = normal_x >> (MBITS_F32 - MBITS_FP4)
345+
normal_x = normal_x.to(tl.uint8)
314346

315-
# For normal numbers, bias is changed from 127 to 1, and for subnormals, we keep exponent as 0.
316-
# Note: E8_BIAS - E2_BIAS = 126, so for normals we subtract that.
317-
e = tl.maximum(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
347+
# Merge results
348+
e2m1_value = tl.full(qx.type.get_block_shapes(), 0x7, dtype=tl.uint8)
349+
e2m1_value = tl.where(normal_mask, normal_x, e2m1_value)
350+
e2m1_value = tl.where(denormal_mask, denormal_x, e2m1_value)
318351

319-
# Combine sign, exponent, and mantissa, while saturating
320-
# rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
321-
e2m1_tmp = tl.minimum((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)
322-
e2m1_value = ((s >> 28) | e2m1_tmp).to(tl.uint8)
352+
# add sign back
353+
sign_lp = s >> (MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4)
354+
sign_lp = sign_lp.to(tl.uint8)
355+
e2m1_value = e2m1_value | sign_lp
323356

324357
e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE, MXFP4_QUANT_BLOCK_SIZE // 2, 2])
325358
evens, odds = tl.split(e2m1_value)
@@ -422,6 +455,10 @@ def dynamic_mxfp4_quant(
422455
SHUFFLE=shuffle,
423456
)
424457

458+
if not shuffle:
459+
# Trim the padding if not shuffled
460+
blockscale_e8m0 = blockscale_e8m0[:M, :scaleN_valid].contiguous()
461+
425462
return (x_fp4.view(dtypes.fp4x2), blockscale_e8m0.view(dtypes.fp8_e8m0))
426463

427464

op_tests/triton_tests/quant/test_quant_mxfp4.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
import pytest
66

77
from aiter.ops.triton.quant import dynamic_mxfp4_quant
8+
from aiter.utility.fp4_utils import (
9+
dynamic_mxfp4_quant as fp4_utils_dynamic_mxfp4_quant,
10+
)
811

912
DEBUG_MODE = False
1013

@@ -202,3 +205,55 @@ def test_dynamic_mxfp4_quant(M: int, N: int, dtype):
202205

203206
torch.testing.assert_close(triton_scale, torch_scale)
204207
torch.testing.assert_close(triton_out, torch_out)
208+
209+
210+
@pytest.mark.parametrize(
211+
"M, N",
212+
[
213+
(1, 4),
214+
(1, 28),
215+
(1, 32),
216+
(1, 64),
217+
(1, 68),
218+
(2, 4),
219+
(2, 28),
220+
(2, 32),
221+
(2, 64),
222+
(2, 68),
223+
(128, 4),
224+
(128, 28),
225+
(128, 32),
226+
(128, 64),
227+
(128, 68),
228+
(256, 32),
229+
(160, 40),
230+
(280, 20),
231+
],
232+
)
233+
@pytest.mark.parametrize("dtype", [torch.bfloat16])
234+
def test_fp4_utils_dynamic_mxfp4_quant(M: int, N: int, dtype):
235+
torch.cuda.empty_cache()
236+
torch.manual_seed(20)
237+
x = torch.randn((M, N), dtype=dtype, device="cuda")
238+
239+
if DEBUG_MODE:
240+
print(f"x.shape={x.shape} x={x}")
241+
242+
fp4_utils_out, fp4_utils_scale = fp4_utils_dynamic_mxfp4_quant(x)
243+
if DEBUG_MODE:
244+
print(
245+
f"fp4_utils_out.shape={fp4_utils_out.shape} fp4_utils_out={fp4_utils_out}"
246+
)
247+
print(
248+
f"fp4_utils_scale.shape={fp4_utils_scale.shape} fp4_utils_scale={fp4_utils_scale}"
249+
)
250+
251+
torch_out, torch_scale = torch_dynamic_mxfp4_quant(x)
252+
if DEBUG_MODE:
253+
print(f"torch_out.shape={torch_out.shape} torch_out={torch_out}")
254+
print(f"torch_scale.shape={torch_scale.shape} torch_scale={torch_scale}")
255+
256+
torch.testing.assert_close(
257+
fp4_utils_scale.view(torch.uint8).cpu(), torch_scale.cpu()
258+
)
259+
torch.testing.assert_close(fp4_utils_out.view(torch.uint8).cpu(), torch_out.cpu())

0 commit comments

Comments (0)