2 changes: 1 addition & 1 deletion .github/workflows/regression_test.yml
@@ -33,7 +33,7 @@ jobs:
gpu-arch-version: "12.1"
- name: CUDA Nightly
runs-on: linux.g5.12xlarge.nvidia.gpu
- torch-spec: '--pre torch==2.5.0.dev20240620+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
+ torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CPU 2.2.2
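Note: the CUDA Nightly job now installs whichever cu121 nightly wheel is current, instead of the previously pinned 2024-06-20 dev build. For local reproduction, the updated torch-spec corresponds roughly to the following install command (the exact wheel version depends on the day the nightly index is queried):

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121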
143 changes: 127 additions & 16 deletions torchao/prototype/mx_formats/custom_cast.py
@@ -107,7 +107,19 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor):
import triton.language as tl

@triton.jit
- def _fp4_packed_to_bf16(x_packed):
+ def _fp4_packed_to_bf16(
+ x_packed,
+ sign_mask_f4,
+ mantissa_mask_f4,
+ mbits_f4_e2m1,
+ ebits_f4_e2m1,
+ f4_e2m1_exp_bias,
+ mbits_f32,
+ ebits_f32,
+ f32_exp_bias,
+ zero_bits_f32,
+ zero_point_five_bits_f32,
+ ):
"""
Input: a tensor of packed fp4 values
Output: a tensor of bfloat16 values
@@ -123,7 +135,7 @@ def _fp4_packed_to_bf16(x_packed):
# output = x_unpacked.to(tl.float32)

# save the sign
- sign_f4 = x & SIGN_MASK_F4
+ sign_f4 = x & sign_mask_f4

# set everything to positive, will add sign back at the end
x_pos = x ^ sign_f4
@@ -138,25 +150,25 @@ def _fp4_packed_to_bf16(x_packed):
denormal_mask = x_pos == 1

# calculate the new exponent and shift it to bits 2:9 of the result
- exp_biased_f4 = x_pos >> MBITS_F4_E2M1
- exp_biased_f32 = exp_biased_f4 - F4_E2M1_EXP_BIAS + F32_EXP_BIAS
- exp_biased_f32 = exp_biased_f32.to(tl.int32) << MBITS_F32
+ exp_biased_f4 = x_pos >> mbits_f4_e2m1
+ exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias
+ exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32

# shift the mantissa to bits 10:32 of the result
- mantissa_f4 = x_pos & MANTISSA_MASK_F4
- mantissa_f32 = mantissa_f4.to(tl.int32) << (MBITS_F32 - MBITS_F4_E2M1)
+ mantissa_f4 = x_pos & mantissa_mask_f4
+ mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1)
output = mantissa_f32

# combine the pieces
result = exp_biased_f32 | mantissa_f32
# result[zero_mask] = ZERO_BITS_F32
- result = tl.where(zero_mask, ZERO_BITS_F32, result)
+ result = tl.where(zero_mask, zero_bits_f32, result)
# result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32
- result = tl.where(denormal_mask, ZERO_POINT_FIVE_BITS_F32, result)
+ result = tl.where(denormal_mask, zero_point_five_bits_f32, result)

# add sign back
sign_f32 = sign_f4.to(tl.int32) << (
- MBITS_F32 - MBITS_F4_E2M1 + EBITS_F32 - EBITS_F4_E2M1
+ mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1
)
result = result | sign_f32

@@ -174,6 +186,16 @@ def triton_f4_to_bf16_kernel(
x_ptr,
output_ptr,
n_elements_in,
+ sign_mask_f4: tl.constexpr,
+ mantissa_mask_f4: tl.constexpr,
+ mbits_f4_e2m1: tl.constexpr,
+ ebits_f4_e2m1: tl.constexpr,
+ f4_e2m1_exp_bias: tl.constexpr,
+ mbits_f32: tl.constexpr,
+ ebits_f32: tl.constexpr,
+ f32_exp_bias: tl.constexpr,
+ zero_bits_f32: tl.constexpr,
+ zero_point_five_bits_f32: tl.constexpr,
BLOCK_SIZE_IN: tl.constexpr,
):
pid = tl.program_id(axis=0)
@@ -187,7 +209,19 @@ def triton_f4_to_bf16_kernel(

# packed uint8
x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
- output = _fp4_packed_to_bf16(x_packed)
+ output = _fp4_packed_to_bf16(
+ x_packed,
+ sign_mask_f4,
+ mantissa_mask_f4,
+ mbits_f4_e2m1,
+ ebits_f4_e2m1,
+ f4_e2m1_exp_bias,
+ mbits_f32,
+ ebits_f32,
+ f32_exp_bias,
+ zero_bits_f32,
+ zero_point_five_bits_f32,
+ )

# set up output offsets
block_start_out = pid * BLOCK_SIZE_OUT
@@ -213,6 +247,18 @@ def triton_f4_to_scaled_bf16_kernel(
output_ptr,
n_elements_in,
mx_block_size: tl.constexpr,
+ sign_mask_f4: tl.constexpr,
+ mantissa_mask_f4: tl.constexpr,
+ mbits_f4_e2m1: tl.constexpr,
+ ebits_f4_e2m1: tl.constexpr,
+ f4_e2m1_exp_bias: tl.constexpr,
+ mbits_f32: tl.constexpr,
+ ebits_f32: tl.constexpr,
+ f32_exp_bias: tl.constexpr,
+ zero_bits_f32: tl.constexpr,
+ zero_point_five_bits_f32: tl.constexpr,
+ e8m0_exponent_bias: tl.constexpr,
+ e8m0_exponent_nan_val: tl.constexpr,
BLOCK_SIZE_IN: tl.constexpr,
):
pid = tl.program_id(axis=0)
@@ -227,7 +273,19 @@ def triton_f4_to_scaled_bf16_kernel(
mask_in = offsets_in < n_elements_in
# packed uint8
x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
- output = _fp4_packed_to_bf16(x_packed)
+ output = _fp4_packed_to_bf16(
+ x_packed,
+ sign_mask_f4,
+ mantissa_mask_f4,
+ mbits_f4_e2m1,
+ ebits_f4_e2m1,
+ f4_e2m1_exp_bias,
+ mbits_f32,
+ ebits_f32,
+ f32_exp_bias,
+ zero_bits_f32,
+ zero_point_five_bits_f32,
+ )

# load scale
block_start_s = pid * BLOCK_SIZE_S
@@ -236,9 +294,9 @@ def triton_f4_to_scaled_bf16_kernel(
s = tl.load(s_ptr + offsets_s, mask=mask_s)

# create the scale in bf16
- s_offset = s.to(tl.int16) - E8M0_EXPONENT_BIAS
+ s_offset = s.to(tl.int16) - e8m0_exponent_bias
s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
- s_fp = tl.where(s != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
+ s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))

# multiply output by scale
# TODO(later): see if manipulating the exponent instead of fp
@@ -263,6 +321,16 @@ def triton_f4_to_bf16_kernel(
x_ptr,
output_ptr,
n_elements_in,
+ sign_mask_f4,
+ mantissa_mask_f4,
+ mbits_f4_e2m1,
+ ebits_f4_e2m1,
+ f4_e2m1_exp_bias,
+ mbits_f32,
+ ebits_f32,
+ f32_exp_bias,
+ zero_bits_f32,
+ zero_point_five_bits_f32,
BLOCK_SIZE_IN,
):
raise AssertionError("unsupported without triton")
@@ -273,6 +341,18 @@ def triton_f4_to_scaled_bf16_kernel(
output_ptr,
n_elements_in,
mx_block_size,
+ sign_mask_f4,
+ mantissa_mask_f4,
+ mbits_f4_e2m1,
+ ebits_f4_e2m1,
+ f4_e2m1_exp_bias,
+ mbits_f32,
+ ebits_f32,
+ f32_exp_bias,
+ zero_bits_f32,
+ zero_point_five_bits_f32,
+ e8m0_exponent_bias,
+ e8m0_exponent_nan_val,
BLOCK_SIZE_IN,
):
raise AssertionError("unsupported without triton")
@@ -294,7 +374,22 @@ def triton_f4_to_bf16(x: torch.Tensor):
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
) # noqa: E731,E501
- triton_f4_to_bf16_kernel[grid](x, output, n_elements_in, BLOCK_SIZE_IN=512)
+ triton_f4_to_bf16_kernel[grid](
+ x,
+ output,
+ n_elements_in,
+ sign_mask_f4=SIGN_MASK_F4,
+ mantissa_mask_f4=MANTISSA_MASK_F4,
+ mbits_f4_e2m1=MBITS_F4_E2M1,
+ ebits_f4_e2m1=EBITS_F4_E2M1,
+ f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+ mbits_f32=MBITS_F32,
+ ebits_f32=EBITS_F32,
+ f32_exp_bias=F32_EXP_BIAS,
+ zero_bits_f32=ZERO_BITS_F32,
+ zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+ BLOCK_SIZE_IN=512,
+ )
return output


@@ -318,7 +413,23 @@ def triton_f4_to_scaled_bf16(
triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
)
triton_f4_to_scaled_bf16_kernel[grid](
- x, s_e8m0, output, n_elements_in, mx_block_size
+ x,
+ s_e8m0,
+ output,
+ n_elements_in,
+ mx_block_size,
+ sign_mask_f4=SIGN_MASK_F4,
+ mantissa_mask_f4=MANTISSA_MASK_F4,
+ mbits_f4_e2m1=MBITS_F4_E2M1,
+ ebits_f4_e2m1=EBITS_F4_E2M1,
+ f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+ mbits_f32=MBITS_F32,
+ ebits_f32=EBITS_F32,
+ f32_exp_bias=F32_EXP_BIAS,
+ zero_bits_f32=ZERO_BITS_F32,
+ zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+ e8m0_exponent_bias=E8M0_EXPONENT_BIAS,
+ e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL,
)
return output

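The custom_cast.py changes above all follow one pattern: constants such as SIGN_MASK_F4 and MBITS_F32, which the @triton.jit kernels previously read as module-level globals, are now threaded through as explicit kernel parameters (tl.constexpr in the Triton kernels and plain arguments in the no-Triton stubs) and supplied at launch time by the Python wrappers. Passing them as tl.constexpr keeps them compile-time constants for Triton while removing the kernels' dependence on globals in the enclosing module. Below is a minimal sketch of that pattern; the kernel, wrapper, and constant names are illustrative and not part of this PR.

import torch
import triton
import triton.language as tl

SIGN_MASK = 0x8  # stands in for a module-level constant like SIGN_MASK_F4

@triton.jit
def _mask_kernel(
    x_ptr,
    out_ptr,
    n_elements,
    sign_mask: tl.constexpr,  # constant arrives as a parameter, not a captured global
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # keep only the bit selected by the constant mask
    tl.store(out_ptr + offsets, x & sign_mask, mask=mask)

def extract_sign_bits(x: torch.Tensor) -> torch.Tensor:
    # the wrapper forwards the constant at launch time, mirroring how
    # triton_f4_to_bf16 now forwards SIGN_MASK_F4 and friends
    output = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _mask_kernel[grid](x, output, n_elements, sign_mask=SIGN_MASK, BLOCK_SIZE=1024)
    return output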