@@ -296,30 +296,63 @@ def _dynamic_mxfp4_quant_kernel_asm_layout(
296296 # S101 -> +/- 3.0
297297 # S110 -> +/- 4.0
298298 # S111 -> +/- 6.0
299+ # FP4 format constants
300+ EXP_BIAS_FP32 : tl .constexpr = 127
301+ EXP_BIAS_FP4 : tl .constexpr = 1
302+ EBITS_F32 : tl .constexpr = 8
303+ EBITS_FP4 : tl .constexpr = 2
304+ MBITS_F32 : tl .constexpr = 23
305+ MBITS_FP4 : tl .constexpr = 1
306+
307+ max_normal : tl .constexpr = 6
308+ min_normal : tl .constexpr = 1
309+
299310 qx = qx .to (tl .uint32 , bitcast = True )
300311
301- # Extract sign, exponents and mantissa fields from FP32
312+ # Extract sign
302313 s = qx & 0x80000000
303- e = ( qx >> 23 ) & 0xFF
304- m = qx & 0x7FFFFF
314+ # Set everything to positive, will add sign back at the end
315+ qx = qx ^ s
305316
306- E8_BIAS : tl .constexpr = 127
307- E2_BIAS : tl .constexpr = 1
317+ qx_fp32 = qx .to (tl .float32 , bitcast = True )
318+ saturate_mask = qx_fp32 >= max_normal
319+ denormal_mask = (not saturate_mask ) & (qx_fp32 < min_normal )
320+ normal_mask = not (saturate_mask | denormal_mask )
308321
309322 # Denormal numbers
310- # If exponent is less than 127, then it's a denormal number
311- # See above, for denormal number mantissa is always 1 and we set bit 1 of mantissa
312- adjusted_exponents = tl .core .sub (E8_BIAS , e + 1 , sanitize_overflow = False )
313- m = tl .where (e < E8_BIAS , (0x400000 | (m >> 1 )) >> adjusted_exponents , m )
323+ denorm_exp : tl .constexpr = (
324+ (EXP_BIAS_FP32 - EXP_BIAS_FP4 ) + (MBITS_F32 - MBITS_FP4 ) + 1
325+ )
326+ denorm_mask_int : tl .constexpr = denorm_exp << MBITS_F32
327+ denorm_mask_float : tl .constexpr = tl .cast (denorm_mask_int , tl .float32 , bitcast = True )
328+
329+ denormal_x = qx_fp32 + denorm_mask_float
330+ denormal_x = denormal_x .to (tl .uint32 , bitcast = True )
331+ denormal_x -= denorm_mask_int
332+ denormal_x = denormal_x .to (tl .uint8 )
333+
334+ # Normal numbers
335+ normal_x = qx
336+ # resulting mantissa is odd
337+ mant_odd = (normal_x >> (MBITS_F32 - MBITS_FP4 )) & 1
338+ # update exponent, rounding bias part 1
339+ val_to_add = ((EXP_BIAS_FP4 - EXP_BIAS_FP32 ) << MBITS_F32 ) + (1 << 21 ) - 1
340+ normal_x += val_to_add
341+ # rounding bias part 2
342+ normal_x += mant_odd
343+ # take the bits!
344+ normal_x = normal_x >> (MBITS_F32 - MBITS_FP4 )
345+ normal_x = normal_x .to (tl .uint8 )
314346
315- # For normal numbers, bias is changed from 127 to 1, and for subnormals, we keep exponent as 0.
316- # Note: E8_BIAS - E2_BIAS = 126, so for normals we subtract that.
317- e = tl .maximum (e , E8_BIAS - E2_BIAS ) - (E8_BIAS - E2_BIAS )
347+ # Merge results
348+ e2m1_value = tl .full (qx .type .get_block_shapes (), 0x7 , dtype = tl .uint8 )
349+ e2m1_value = tl .where (normal_mask , normal_x , e2m1_value )
350+ e2m1_value = tl .where (denormal_mask , denormal_x , e2m1_value )
318351
319- # Combine sign, exponent, and mantissa, while saturating
320- # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
321- e2m1_tmp = tl . minimum (((( e << 2 ) | ( m >> 21 )) + 1 ) >> 1 , 0x7 )
322- e2m1_value = (( s >> 28 ) | e2m1_tmp ). to ( tl . uint8 )
352+ # add sign back
353+ sign_lp = s >> ( MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4 )
354+ sign_lp = sign_lp . to ( tl . uint8 )
355+ e2m1_value = e2m1_value | sign_lp
323356
324357 e2m1_value = tl .reshape (e2m1_value , [BLOCK_SIZE , MXFP4_QUANT_BLOCK_SIZE // 2 , 2 ])
325358 evens , odds = tl .split (e2m1_value )
@@ -422,6 +455,10 @@ def dynamic_mxfp4_quant(
422455 SHUFFLE = shuffle ,
423456 )
424457
458+ if not shuffle :
459+ # Trim the padding if not shuffled
460+ blockscale_e8m0 = blockscale_e8m0 [:M , :scaleN_valid ].contiguous ()
461+
425462 return (x_fp4 .view (dtypes .fp4x2 ), blockscale_e8m0 .view (dtypes .fp8_e8m0 ))
426463
427464
0 commit comments