Commit 8236a87

Renaming fpx to floatx (#877)
* Renaming fpx to floatx

  Summary: att, to allow float8 code to be moved to the floatx folder. fpx_weight_only is not yet renamed to floatx_weight_only; we'll do that in the future once we have more clarity on which specific dtypes we want to support (e.g. maybe we'll only support fp4 and fp6).

  Test Plan: python test/dtypes/test_floatx.py

* fix test_ops
1 parent: f82071d · commit: 8236a87

17 files changed, +188 -188 lines changed
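At a glance, the rename amounts to the following module and symbol changes. This is a sketch assembled from the hunks below; only the old and new import paths come from this commit, the surrounding comments are illustrative.

# Old (pre-commit) imports, removed in this change:
#   from torchao.dtypes import to_affine_quantized_fpx
#   from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType

# New (post-commit) imports:
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes.floatx import (
    FloatxTensorCoreAQTLayout,
    FloatxTensorCoreLayoutType,
    to_scaled_tc_floatx,
    from_scaled_tc_floatx,
)

# Per the commit message, the quantize_ config keeps its old name for now:
from torchao.quantization import quantize_, fpx_weight_only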

README.md

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ The best example we have combining the composability of lower bit dtype with com
 
 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()` so if you love writing kernels but hate packaging them so they work all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow
 
-1. [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fp6_llm_weight_only())`
+1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
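As context for the README line above, a minimal end-to-end sketch of the fp6 weight-only flow. The model, shapes, and CUDA device here are illustrative assumptions; quantize_ and fpx_weight_only(3, 2) are the calls shown in the diff.

import torch
from torchao.quantization import quantize_, fpx_weight_only

# Illustrative fp16 model on CUDA; the fp6 kernels target half-precision linears.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
).half().cuda()

# fpx_weight_only(ebits, mbits): (3, 2) selects the fp6 e3m2 format.
quantize_(model, fpx_weight_only(3, 2))

x = torch.randn(8, 1024, dtype=torch.half, device="cuda")
out = model(x)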

benchmarks/benchmark_fp6.py

Lines changed: 3 additions & 3 deletions
@@ -1,15 +1,15 @@
 import torch
 import pandas as pd
 import torch.nn.functional as F
-from torchao.dtypes import to_affine_quantized_fpx
-from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType
+from torchao.dtypes import to_affine_quantized_floatx
+from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm
 
 
 def benchmark(m: int, k: int, n: int):
     float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
-    fp6_weight = to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2))
+    fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2))
     fp16_weight = fp6_weight.dequantize(torch.half)
 
     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
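The renamed conversion entry point used by the benchmark can also be exercised on its own. A minimal sketch that mirrors the hunk above; the CUDA device and tensor shapes are illustrative.

import torch
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes.floatx import FloatxTensorCoreLayoutType

# Quantize a fp16 weight to fp6 (ebits=3, mbits=2) in the tensor-core layout,
# then dequantize back to half precision for a rough error check.
float_data = torch.randn(256, 64, dtype=torch.half, device="cuda")
fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2))
fp16_weight = fp6_weight.dequantize(torch.half)
print((float_data - fp16_weight).abs().mean())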

test/dtypes/test_fpx.py renamed to test/dtypes/test_floatx.py

Lines changed: 36 additions & 36 deletions
@@ -8,14 +8,14 @@
     parametrize,
     run_tests,
 )
-from torchao.dtypes.fpx import (
-    FpxTensorCoreAQTLayout,
-    FpxTensorCoreLayoutType,
-    to_scaled_tc_fpx,
-    from_scaled_tc_fpx,
+from torchao.dtypes.floatx import (
+    FloatxTensorCoreAQTLayout,
+    FloatxTensorCoreLayoutType,
+    to_scaled_tc_floatx,
+    from_scaled_tc_floatx,
 )
-from torchao.dtypes.fpx.fpx import _pack_tc_fpx, _pack_tc_fp6
-from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
+from torchao.dtypes.floatx.floatx import _pack_tc_floatx, _pack_tc_fp6
+from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32
 from torchao.quantization import (
     quantize_,
     fpx_weight_only,
@@ -25,71 +25,71 @@
 
 
 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
-_FPx_DTYPES = [(3, 2), (2, 2)]
+_Floatx_DTYPES = [(3, 2), (2, 2)]
 
 
-class TestFpxTensorCoreAQTLayout(TestCase):
+class TestFloatxTensorCoreAQTLayout(TestCase):
     @parametrize("device", _DEVICES)
     def test_pack_tc_fp6_correctness(self, device):
         x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)
 
-        expected = _pack_tc_fpx(x, 6)
+        expected = _pack_tc_floatx(x, 6)
         actual = _pack_tc_fp6(x)
         torch.testing.assert_close(actual, expected)
 
-    @parametrize("ebits,mbits", _FPx_DTYPES)
+    @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("device", _DEVICES)
-    def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device):
+    def test_to_scaled_tc_floatx_compile(self, ebits, mbits, device):
         x = torch.randn(256, 64, device=device)
 
-        expected = to_scaled_tc_fpx(x, ebits, mbits)
-        actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits)
+        expected = to_scaled_tc_floatx(x, ebits, mbits)
+        actual = torch.compile(to_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits)
         torch.testing.assert_close(actual, expected)
 
-    @parametrize("ebits,mbits", _FPx_DTYPES)
+    @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("device", _DEVICES)
-    def test_from_tc_fpx_correctness(self, ebits, mbits, device):
+    def test_from_tc_floatx_correctness(self, ebits, mbits, device):
         x = torch.randn(256, 64, device=device) * 100
 
-        # quantize and dequantize so that the values are exactly representable in FPx
-        x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)
+        # quantize and dequantize so that the values are exactly representable in Floatx
+        x = _floatx_unpacked_to_f32(_f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits)
 
-        tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)
-        actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale)
+        tc_floatx, scale = to_scaled_tc_floatx(x, ebits, mbits)
+        actual = from_scaled_tc_floatx(tc_floatx, ebits, mbits, scale=scale)
         torch.testing.assert_close(actual, x)
 
-    @parametrize("ebits,mbits", _FPx_DTYPES)
+    @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("device", _DEVICES)
-    def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
+    def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device):
         M, N = 256, 64
         nbits = 1 + ebits + mbits
         x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device)
         scale = torch.randn(M, device=device)
 
-        expected = from_scaled_tc_fpx(x, ebits, mbits, scale)
-        actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
+        expected = from_scaled_tc_floatx(x, ebits, mbits, scale)
+        actual = torch.compile(from_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits, scale)
         torch.testing.assert_close(actual, expected)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    @parametrize("ebits,mbits", _FPx_DTYPES)
+    @parametrize("ebits,mbits", _Floatx_DTYPES)
     def test_to_copy_device(self, ebits, mbits):
         from torchao.quantization.quant_primitives import (
-            choose_qparams_affine_fpx,
-            quantize_affine_fpx,
+            choose_qparams_affine_floatx,
+            quantize_affine_floatx,
         )
 
         x = torch.randn(256, 64)
-        scale = choose_qparams_affine_fpx(x, ebits, mbits)
-        x = quantize_affine_fpx(x, scale, ebits, mbits)
-        layout_type = FpxTensorCoreLayoutType(ebits, mbits)
-        fpx_layout_tensor = FpxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
-        assert fpx_layout_tensor.device.type == "cuda"
-        fpx_layout_tensor = fpx_layout_tensor.cpu()
-        assert fpx_layout_tensor.device.type == "cpu"
+        scale = choose_qparams_affine_floatx(x, ebits, mbits)
+        x = quantize_affine_floatx(x, scale, ebits, mbits)
+        layout_type = FloatxTensorCoreLayoutType(ebits, mbits)
+        floatx_layout_tensor = FloatxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
+        assert floatx_layout_tensor.device.type == "cuda"
+        floatx_layout_tensor = floatx_layout_tensor.cpu()
+        assert floatx_layout_tensor.device.type == "cpu"
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
-    @parametrize("ebits,mbits", _FPx_DTYPES)
+    @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("bias", [False, True])
     def test_fpx_weight_only(self, ebits, mbits, bias):
         N, OC, IC = 4, 256, 64
@@ -106,7 +106,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias):
         torch.testing.assert_close(actual, expected)
 
 
-instantiate_parametrized_tests(TestFpxTensorCoreAQTLayout)
+instantiate_parametrized_tests(TestFloatxTensorCoreAQTLayout)
 
 
 if __name__ == "__main__":
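For readers following the renamed helpers, the pack/unpack round trip these tests rely on looks roughly like this on its own. Shapes are illustrative; the pre-rounding step mirrors test_from_tc_floatx_correctness above.

import torch
from torchao.dtypes.floatx import to_scaled_tc_floatx, from_scaled_tc_floatx
from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32

ebits, mbits = 3, 2  # fp6 (e3m2)
x = torch.randn(256, 64) * 100

# Pre-round so every value is exactly representable in fp6, as the test does.
x = _floatx_unpacked_to_f32(_f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits)

packed, scale = to_scaled_tc_floatx(x, ebits, mbits)
restored = from_scaled_tc_floatx(packed, ebits, mbits, scale=scale)
torch.testing.assert_close(restored, x)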

test/dtypes/test_uintx.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 
 import torch
 
-from torchao.dtypes.uintx.Uintx import to_uintx
+from torchao.dtypes.uintx.uintx import to_uintx
 from torchao.quantization.quant_api import quantize_, uintx_weight_only
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,

test/test_ops.py

Lines changed: 13 additions & 13 deletions
@@ -11,7 +11,7 @@
 )
 from torch.testing._internal.optests import opcheck
 from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
-from torchao.dtypes.fpx import from_scaled_tc_fpx
+from torchao.dtypes.floatx import from_scaled_tc_floatx
 from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
 import pytest
 
@@ -33,13 +33,13 @@
 
 
 class TestOps(TestCase):
-    def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
+    def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
         # Randomly initialize each byte
         nbits = 1 + ebits + mbits
-        fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
+        floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
         scale = torch.rand(OC).half() + 0.5
         fp16_act = torch.rand(BS, IC).half() + 0.5
-        return fpx_weight.to(device), scale.to(device), fp16_act.to(device)
+        return floatx_weight.to(device), scale.to(device), fp16_act.to(device)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
@@ -48,28 +48,28 @@ def test_quant_llm_linear(self, ebits, mbits):
         OC = 256
         IC = 256
         splitK = 1
-        fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")
 
         # smoke test
-        torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
+        torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)
 
         # comprehensive testing
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils)
+        opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, floatx_weight, scale, splitK), test_utils=test_utils)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
     def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
         # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
-        fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")
 
-        results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
+        results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)
 
-        fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half()
+        fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half()
         results_fp16 = fp16_act @ fp16_weight.T
 
-        error = (results_fpx - results_fp16).abs().mean()
+        error = (results_floatx - results_fp16).abs().mean()
         gt = results_fp16.abs().mean()
         relative_error = error / gt
         assert relative_error < 1e-3
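The same correctness check can be run outside the test harness. A minimal sketch for a CUDA machine, using only the ops and helpers that appear in this diff; the batch size and channel counts are illustrative.

import torch
import torchao
from torchao.dtypes.floatx import from_scaled_tc_floatx

ebits, mbits, splitK = 3, 2, 1  # fp6 (e3m2)
BS, OC, IC = 2, 256, 256
nbits = 1 + ebits + mbits

# Random packed fp6 weight bytes, per-channel scales, and a fp16 activation.
floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8, device="cuda")
scale = (torch.rand(OC).half() + 0.5).cuda()
fp16_act = (torch.rand(BS, IC).half() + 0.5).cuda()

results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)

# Reference: dequantize the packed weight and run a plain fp16 matmul.
fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half()
results_fp16 = fp16_act @ fp16_weight.T

relative_error = (results_floatx - results_fp16).abs().mean() / results_fp16.abs().mean()
assert relative_error < 1e-3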
@@ -319,7 +319,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
 MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
 
 MARLIN_TEST_PARAMS = list(itertools.product(
-    MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, 
+    MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS,
     MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
 ))
 
@@ -405,7 +405,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
     workspace_24 = marlin_24_workspace(size_n)
 
     fn_inputs = (
-        input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, 
+        input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
         num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out,
     )
     output = torchao.ops.marlin_24_gemm(*fn_inputs)
