Commit bc234a8
fix mx triton kernel after PyTorch triton pin change
Summary:

The Triton pin was updated recently in pytorch/pytorch#126098. In the new Triton version, `triton.jit` functions can only access global variables of type `tl.constexpr`. Due to the current structure of the code, and the fact that these constants are also used by non-Triton code, I think the best option is to stop using globals in the MX Triton kernels. The PR lifts all of these constants to kernel function arguments.

Test Plan:

```
pytest test/prototype/mx_formats/test_custom_cast.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
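For reference, a minimal, self-contained sketch of the pattern this commit applies (the kernel, wrapper, and constant here are illustrative, not the torchao kernels; it assumes a uint8 CUDA input tensor): the constant stays a plain module-level Python value so non-Triton code can keep using it, and the kernel receives it as a `tl.constexpr` argument at launch time instead of reading a global.

```python
import torch
import triton
import triton.language as tl

# Hypothetical module-level constant, also usable by non-Triton code.
# (Illustrative only; the real constants live in torchao/prototype/mx_formats.)
SIGN_MASK_F4 = 0x8

@triton.jit
def _mask_kernel(x_ptr, out_ptr, n, sign_mask: tl.constexpr, BLOCK: tl.constexpr):
    # the constant arrives as a constexpr kernel argument instead of a global
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    result = (x & sign_mask).to(tl.uint8)
    tl.store(out_ptr + offsets, result, mask=mask)

def apply_sign_mask(x: torch.Tensor) -> torch.Tensor:
    # the wrapper forwards the plain Python constant at launch time
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)  # noqa: E731
    _mask_kernel[grid](x, out, n, sign_mask=SIGN_MASK_F4, BLOCK=1024)
    return out
```

The diff below applies the same pattern to all of the fp4/f32/e8m0 constants used by `_fp4_packed_to_bf16` and the two kernels.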
1 parent: 664f073


torchao/prototype/mx_formats/custom_cast.py

Lines changed: 121 additions & 16 deletions
```diff
@@ -443,7 +443,19 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor):
     import triton.language as tl
 
     @triton.jit
-    def _fp4_packed_to_bf16(x_packed):
+    def _fp4_packed_to_bf16(
+        x_packed,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
+    ):
         """
         Input: a tensor of packed fp4 values
         Output: a tensor of bfloat16 values
@@ -459,7 +471,7 @@ def _fp4_packed_to_bf16(x_packed):
         # output = x_unpacked.to(tl.float32)
 
         # save the sign
-        sign_f4 = x & SIGN_MASK_F4
+        sign_f4 = x & sign_mask_f4
 
         # set everything to positive, will add sign back at the end
         x_pos = x ^ sign_f4
@@ -474,25 +486,25 @@ def _fp4_packed_to_bf16(x_packed):
         denormal_mask = x_pos == 1
 
         # calculate the new exponent and shift it to bits 2:9 of the result
-        exp_biased_f4 = x_pos >> MBITS_F4_E2M1
-        exp_biased_f32 = exp_biased_f4 - F4_E2M1_EXP_BIAS + F32_EXP_BIAS
-        exp_biased_f32 = exp_biased_f32.to(tl.int32) << MBITS_F32
+        exp_biased_f4 = x_pos >> mbits_f4_e2m1
+        exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias
+        exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32
 
         # shift the mantissa to bits 10:32 of the result
-        mantissa_f4 = x_pos & MANTISSA_MASK_F4
-        mantissa_f32 = mantissa_f4.to(tl.int32) << (MBITS_F32 - MBITS_F4_E2M1)
+        mantissa_f4 = x_pos & mantissa_mask_f4
+        mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1)
         output = mantissa_f32
 
         # combine the pieces
         result = exp_biased_f32 | mantissa_f32
         # result[zero_mask] = ZERO_BITS_F32
-        result = tl.where(zero_mask, ZERO_BITS_F32, result)
+        result = tl.where(zero_mask, zero_bits_f32, result)
         # result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32
-        result = tl.where(denormal_mask, ZERO_POINT_FIVE_BITS_F32, result)
+        result = tl.where(denormal_mask, zero_point_five_bits_f32, result)
 
         # add sign back
         sign_f32 = sign_f4.to(tl.int32) << (
-            MBITS_F32 - MBITS_F4_E2M1 + EBITS_F32 - EBITS_F4_E2M1
+            mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1
         )
         result = result | sign_f32
 
@@ -510,6 +522,16 @@ def triton_f4_to_bf16_kernel(
         x_ptr,
         output_ptr,
         n_elements_in,
+        sign_mask_f4: tl.constexpr,
+        mantissa_mask_f4: tl.constexpr,
+        mbits_f4_e2m1: tl.constexpr,
+        ebits_f4_e2m1: tl.constexpr,
+        f4_e2m1_exp_bias: tl.constexpr,
+        mbits_f32: tl.constexpr,
+        ebits_f32: tl.constexpr,
+        f32_exp_bias: tl.constexpr,
+        zero_bits_f32: tl.constexpr,
+        zero_point_five_bits_f32: tl.constexpr,
         BLOCK_SIZE_IN: tl.constexpr,
     ):
         pid = tl.program_id(axis=0)
@@ -523,7 +545,19 @@ def triton_f4_to_bf16_kernel(
 
         # packed uint8
         x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-        output = _fp4_packed_to_bf16(x_packed)
+        output = _fp4_packed_to_bf16(
+            x_packed,
+            sign_mask_f4,
+            mantissa_mask_f4,
+            mbits_f4_e2m1,
+            ebits_f4_e2m1,
+            f4_e2m1_exp_bias,
+            mbits_f32,
+            ebits_f32,
+            f32_exp_bias,
+            zero_bits_f32,
+            zero_point_five_bits_f32,
+        )
 
         # set up output offsets
         block_start_out = pid * BLOCK_SIZE_OUT
@@ -549,6 +583,18 @@ def triton_f4_to_scaled_bf16_kernel(
         output_ptr,
         n_elements_in,
         mx_block_size: tl.constexpr,
+        sign_mask_f4: tl.constexpr,
+        mantissa_mask_f4: tl.constexpr,
+        mbits_f4_e2m1: tl.constexpr,
+        ebits_f4_e2m1: tl.constexpr,
+        f4_e2m1_exp_bias: tl.constexpr,
+        mbits_f32: tl.constexpr,
+        ebits_f32: tl.constexpr,
+        f32_exp_bias: tl.constexpr,
+        zero_bits_f32: tl.constexpr,
+        zero_point_five_bits_f32: tl.constexpr,
+        e8m0_exponent_bias: tl.constexpr,
+        e8m0_exponent_nan_val: tl.constexpr,
         BLOCK_SIZE_IN: tl.constexpr,
     ):
         pid = tl.program_id(axis=0)
@@ -563,7 +609,19 @@ def triton_f4_to_scaled_bf16_kernel(
         mask_in = offsets_in < n_elements_in
         # packed uint8
         x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-        output = _fp4_packed_to_bf16(x_packed)
+        output = _fp4_packed_to_bf16(
+            x_packed,
+            sign_mask_f4,
+            mantissa_mask_f4,
+            mbits_f4_e2m1,
+            ebits_f4_e2m1,
+            f4_e2m1_exp_bias,
+            mbits_f32,
+            ebits_f32,
+            f32_exp_bias,
+            zero_bits_f32,
+            zero_point_five_bits_f32,
+        )
 
         # load scale
         block_start_s = pid * BLOCK_SIZE_S
@@ -572,9 +630,9 @@ def triton_f4_to_scaled_bf16_kernel(
         s = tl.load(s_ptr + offsets_s, mask=mask_s)
 
         # create the scale in bf16
-        s_offset = s.to(tl.int16) - E8M0_EXPONENT_BIAS
+        s_offset = s.to(tl.int16) - e8m0_exponent_bias
         s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
-        s_fp = tl.where(s != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
+        s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))
 
         # multiply output by scale
         # TODO(later): see if manipulating the exponent instead of fp
@@ -599,6 +657,16 @@ def triton_f4_to_bf16_kernel(
         x_ptr,
         output_ptr,
         n_elements_in,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
         BLOCK_SIZE_IN,
     ):
         raise AssertionError("unsupported without triton")
@@ -609,6 +677,18 @@ def triton_f4_to_scaled_bf16_kernel(
         output_ptr,
         n_elements_in,
         mx_block_size,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
+        e8m0_exponent_bias,
+        e8m0_exponent_nan_val,
         BLOCK_SIZE_IN,
     ):
         raise AssertionError("unsupported without triton")
@@ -630,7 +710,20 @@ def triton_f4_to_bf16(x: torch.Tensor):
     grid = lambda meta: (  # noqa: E731
         triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
     )  # noqa: E731,E501
-    triton_f4_to_bf16_kernel[grid](x, output, n_elements_in, BLOCK_SIZE_IN=512)
+    triton_f4_to_bf16_kernel[grid](
+        x, output, n_elements_in,
+        sign_mask_f4=SIGN_MASK_F4,
+        mantissa_mask_f4=MANTISSA_MASK_F4,
+        mbits_f4_e2m1=MBITS_F4_E2M1,
+        ebits_f4_e2m1=EBITS_F4_E2M1,
+        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+        mbits_f32=MBITS_F32,
+        ebits_f32=EBITS_F32,
+        f32_exp_bias=F32_EXP_BIAS,
+        zero_bits_f32=ZERO_BITS_F32,
+        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+        BLOCK_SIZE_IN=512,
+    )
     return output
 
 
@@ -654,7 +747,19 @@ def triton_f4_to_scaled_bf16(
         triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
     )
     triton_f4_to_scaled_bf16_kernel[grid](
-        x, s_e8m0, output, n_elements_in, mx_block_size
+        x, s_e8m0, output, n_elements_in, mx_block_size,
+        sign_mask_f4=SIGN_MASK_F4,
+        mantissa_mask_f4=MANTISSA_MASK_F4,
+        mbits_f4_e2m1=MBITS_F4_E2M1,
+        ebits_f4_e2m1=EBITS_F4_E2M1,
+        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+        mbits_f32=MBITS_F32,
+        ebits_f32=EBITS_F32,
+        f32_exp_bias=F32_EXP_BIAS,
+        zero_bits_f32=ZERO_BITS_F32,
+        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+        e8m0_exponent_bias=E8M0_EXPONENT_BIAS,
+        e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL,
+    )
     return output
 
```
