@@ -443,7 +443,19 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor):
     import triton.language as tl
 
     @triton.jit
-    def _fp4_packed_to_bf16(x_packed):
+    def _fp4_packed_to_bf16(
+        x_packed,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
+    ):
         """
         Input: a tensor of packed fp4 values
         Output: a tensor of bfloat16 values
@@ -459,7 +471,7 @@ def _fp4_packed_to_bf16(x_packed):
         # output = x_unpacked.to(tl.float32)
 
         # save the sign
-        sign_f4 = x & SIGN_MASK_F4
+        sign_f4 = x & sign_mask_f4
 
         # set everything to positive, will add sign back at the end
         x_pos = x ^ sign_f4
@@ -474,25 +486,25 @@ def _fp4_packed_to_bf16(x_packed):
         denormal_mask = x_pos == 1
 
         # calculate the new exponent and shift it to bits 2:9 of the result
-        exp_biased_f4 = x_pos >> MBITS_F4_E2M1
-        exp_biased_f32 = exp_biased_f4 - F4_E2M1_EXP_BIAS + F32_EXP_BIAS
-        exp_biased_f32 = exp_biased_f32.to(tl.int32) << MBITS_F32
+        exp_biased_f4 = x_pos >> mbits_f4_e2m1
+        exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias
+        exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32
 
         # shift the mantissa to bits 10:32 of the result
-        mantissa_f4 = x_pos & MANTISSA_MASK_F4
-        mantissa_f32 = mantissa_f4.to(tl.int32) << (MBITS_F32 - MBITS_F4_E2M1)
+        mantissa_f4 = x_pos & mantissa_mask_f4
+        mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1)
         output = mantissa_f32
 
         # combine the pieces
         result = exp_biased_f32 | mantissa_f32
         # result[zero_mask] = ZERO_BITS_F32
-        result = tl.where(zero_mask, ZERO_BITS_F32, result)
+        result = tl.where(zero_mask, zero_bits_f32, result)
         # result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32
-        result = tl.where(denormal_mask, ZERO_POINT_FIVE_BITS_F32, result)
+        result = tl.where(denormal_mask, zero_point_five_bits_f32, result)
 
         # add sign back
         sign_f32 = sign_f4.to(tl.int32) << (
-            MBITS_F32 - MBITS_F4_E2M1 + EBITS_F32 - EBITS_F4_E2M1
+            mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1
         )
         result = result | sign_f32
 
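
The constants being replaced by parameters above are defined elsewhere in the file and are not part of this diff. For reference, based on the fp4 e2m1 and IEEE-754 float32 bit layouts, they would plausibly hold values along these lines (inferred from the formats, not copied from the repository):

SIGN_MASK_F4 = 0x8                      # sign bit of a 4-bit e2m1 value
MANTISSA_MASK_F4 = 0x1                  # single mantissa bit of e2m1
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
F4_E2M1_EXP_BIAS = 1
EBITS_F32, MBITS_F32 = 8, 23
F32_EXP_BIAS = 127
ZERO_BITS_F32 = 0x0                     # float32 bit pattern of +0.0
ZERO_POINT_FIVE_BITS_F32 = 0x3F000000   # float32 bit pattern of 0.5
E8M0_EXPONENT_BIAS = 127
E8M0_EXPONENT_NAN_VAL = 255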
@@ -510,6 +522,16 @@ def triton_f4_to_bf16_kernel(
         x_ptr,
         output_ptr,
         n_elements_in,
+        sign_mask_f4: tl.constexpr,
+        mantissa_mask_f4: tl.constexpr,
+        mbits_f4_e2m1: tl.constexpr,
+        ebits_f4_e2m1: tl.constexpr,
+        f4_e2m1_exp_bias: tl.constexpr,
+        mbits_f32: tl.constexpr,
+        ebits_f32: tl.constexpr,
+        f32_exp_bias: tl.constexpr,
+        zero_bits_f32: tl.constexpr,
+        zero_point_five_bits_f32: tl.constexpr,
         BLOCK_SIZE_IN: tl.constexpr,
     ):
         pid = tl.program_id(axis=0)
@@ -523,7 +545,19 @@ def triton_f4_to_bf16_kernel(
 
         # packed uint8
         x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-        output = _fp4_packed_to_bf16(x_packed)
+        output = _fp4_packed_to_bf16(
+            x_packed,
+            sign_mask_f4,
+            mantissa_mask_f4,
+            mbits_f4_e2m1,
+            ebits_f4_e2m1,
+            f4_e2m1_exp_bias,
+            mbits_f32,
+            ebits_f32,
+            f32_exp_bias,
+            zero_bits_f32,
+            zero_point_five_bits_f32,
+        )
 
         # set up output offsets
         block_start_out = pid * BLOCK_SIZE_OUT
@@ -549,6 +583,18 @@ def triton_f4_to_scaled_bf16_kernel(
         output_ptr,
         n_elements_in,
         mx_block_size: tl.constexpr,
+        sign_mask_f4: tl.constexpr,
+        mantissa_mask_f4: tl.constexpr,
+        mbits_f4_e2m1: tl.constexpr,
+        ebits_f4_e2m1: tl.constexpr,
+        f4_e2m1_exp_bias: tl.constexpr,
+        mbits_f32: tl.constexpr,
+        ebits_f32: tl.constexpr,
+        f32_exp_bias: tl.constexpr,
+        zero_bits_f32: tl.constexpr,
+        zero_point_five_bits_f32: tl.constexpr,
+        e8m0_exponent_bias: tl.constexpr,
+        e8m0_exponent_nan_val: tl.constexpr,
         BLOCK_SIZE_IN: tl.constexpr,
     ):
         pid = tl.program_id(axis=0)
@@ -563,7 +609,19 @@ def triton_f4_to_scaled_bf16_kernel(
         mask_in = offsets_in < n_elements_in
         # packed uint8
         x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-        output = _fp4_packed_to_bf16(x_packed)
+        output = _fp4_packed_to_bf16(
+            x_packed,
+            sign_mask_f4,
+            mantissa_mask_f4,
+            mbits_f4_e2m1,
+            ebits_f4_e2m1,
+            f4_e2m1_exp_bias,
+            mbits_f32,
+            ebits_f32,
+            f32_exp_bias,
+            zero_bits_f32,
+            zero_point_five_bits_f32,
+        )
 
         # load scale
         block_start_s = pid * BLOCK_SIZE_S
@@ -572,9 +630,9 @@ def triton_f4_to_scaled_bf16_kernel(
         s = tl.load(s_ptr + offsets_s, mask=mask_s)
 
         # create the scale in bf16
-        s_offset = s.to(tl.int16) - E8M0_EXPONENT_BIAS
+        s_offset = s.to(tl.int16) - e8m0_exponent_bias
         s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
-        s_fp = tl.where(s != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
+        s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))
 
         # multiply output by scale
         # TODO(later): see if manipulating the exponent instead of fp
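
For readers tracing the scale handling above, here is a minimal plain-PyTorch sketch of the same e8m0 decode, assuming e8m0_exponent_bias is 127 and e8m0_exponent_nan_val is 255 (reasonable values for an e8m0 format, but not stated in this diff):

import torch

def e8m0_to_f32_reference(s_e8m0: torch.Tensor) -> torch.Tensor:
    # Decode an e8m0 scale (biased 8-bit exponent, no mantissa) as 2 ** (s - bias),
    # mapping the reserved all-ones encoding to NaN, mirroring the kernel logic.
    s_offset = s_e8m0.to(torch.int32) - 127
    s_fp = torch.pow(2.0, s_offset.to(torch.float32))
    return torch.where(s_e8m0 != 255, s_fp, torch.full_like(s_fp, float("nan")))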
@@ -599,6 +657,16 @@ def triton_f4_to_bf16_kernel(
         x_ptr,
         output_ptr,
         n_elements_in,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
         BLOCK_SIZE_IN,
     ):
         raise AssertionError("unsupported without triton")
@@ -609,6 +677,18 @@ def triton_f4_to_scaled_bf16_kernel(
         output_ptr,
         n_elements_in,
         mx_block_size,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
+        e8m0_exponent_bias,
+        e8m0_exponent_nan_val,
         BLOCK_SIZE_IN,
     ):
         raise AssertionError("unsupported without triton")
@@ -630,7 +710,22 @@ def triton_f4_to_bf16(x: torch.Tensor):
     grid = lambda meta: (  # noqa: E731
         triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
     )  # noqa: E731,E501
-    triton_f4_to_bf16_kernel[grid](x, output, n_elements_in, BLOCK_SIZE_IN=512)
+    triton_f4_to_bf16_kernel[grid](
+        x,
+        output,
+        n_elements_in,
+        sign_mask_f4=SIGN_MASK_F4,
+        mantissa_mask_f4=MANTISSA_MASK_F4,
+        mbits_f4_e2m1=MBITS_F4_E2M1,
+        ebits_f4_e2m1=EBITS_F4_E2M1,
+        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+        mbits_f32=MBITS_F32,
+        ebits_f32=EBITS_F32,
+        f32_exp_bias=F32_EXP_BIAS,
+        zero_bits_f32=ZERO_BITS_F32,
+        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+        BLOCK_SIZE_IN=512,
+    )
     return output
 
 
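A hypothetical call site for the updated wrapper; the input shape, device, and the assumption that each packed uint8 byte unpacks to two bfloat16 values are illustrative, not taken from this diff:

import torch

x_packed = torch.randint(0, 256, (1024,), dtype=torch.uint8, device="cuda")
x_bf16 = triton_f4_to_bf16(x_packed)
assert x_bf16.dtype == torch.bfloat16
assert x_bf16.numel() == 2 * x_packed.numel()  # assumed 2x unpack factor
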
@@ -654,7 +749,23 @@ def triton_f4_to_scaled_bf16(
         triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
     )
     triton_f4_to_scaled_bf16_kernel[grid](
-        x, s_e8m0, output, n_elements_in, mx_block_size
+        x,
+        s_e8m0,
+        output,
+        n_elements_in,
+        mx_block_size,
+        sign_mask_f4=SIGN_MASK_F4,
+        mantissa_mask_f4=MANTISSA_MASK_F4,
+        mbits_f4_e2m1=MBITS_F4_E2M1,
+        ebits_f4_e2m1=EBITS_F4_E2M1,
+        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+        mbits_f32=MBITS_F32,
+        ebits_f32=EBITS_F32,
+        f32_exp_bias=F32_EXP_BIAS,
+        zero_bits_f32=ZERO_BITS_F32,
+        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+        e8m0_exponent_bias=E8M0_EXPONENT_BIAS,
+        e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL,
     )
     return output
 