Commit c34fe82

Remove support for quant_llm_linear
Summary:
- Deleted fp6_linear.cu and the rest of the fp6_llm folder
- Modified torchao/ops.py and test/test_ops.py to remove quant_llm_linear calls
- Removed all tests/references to floatx_tensor_core_layout and FloatXTensorCoreLayout
- Removed all tests/references to FPXWeightOnlyConfig

Tasks: Related to issue #3516 (github.com//issues/3516)

ghstack-source-id: fe8afeb
Pull-Request: #3520
1 parent a8fa9e5 commit c34fe82

34 files changed (+3, -2664 lines)
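
For orientation, a minimal sketch of the user-facing surface this commit removes, pieced together from the deleted tests and configs below. The quantize_ entry point and the Linear model setup are assumptions for illustration, not taken from this diff; none of this code works once the commit lands:

    import torch
    import torchao
    from torchao.quantization import quantize_, FPXWeightOnlyConfig  # FPXWeightOnlyConfig is removed by this commit

    # Config-level API (removed): weight-only fp6 quantization, ebits=3, mbits=2.
    model = torch.nn.Linear(256, 256).half().cuda()
    quantize_(model, FPXWeightOnlyConfig(ebits=3, mbits=2))

    # Kernel-level custom op (removed; it was backed by the deleted fp6_linear.cu):
    # out = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)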

benchmarks/benchmark_fp6.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

benchmarks/microbenchmarks/utils.py

Lines changed: 1 addition & 4 deletions
@@ -17,7 +17,6 @@
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
-    FPXWeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt4WeightConfig,
@@ -230,9 +229,7 @@ def string_to_config(
         from torchao.dtypes import MarlinSparseLayout
 
         return Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1)
-    if "fp6" in quantization:
-        return FPXWeightOnlyConfig(3, 2)
-    elif "uintx" in quantization:
+    if "uintx" in quantization:
         # uintx-nbits-group_size, e.g. "uintx-2-64"
         if "hqq" in quantization:
             # uintx-nbits-group_size-hqq
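
Context for the removed "fp6" shortcut above: FPXWeightOnlyConfig(3, 2) means ebits=3 and mbits=2, i.e. a 6-bit float with 1 sign, 3 exponent, and 2 mantissa bits, consistent with the nbits = 1 + ebits + mbits relation used by the test helper deleted further down. A trivial illustrative check:

    ebits, mbits = 3, 2        # arguments of the removed FPXWeightOnlyConfig(3, 2)
    nbits = 1 + ebits + mbits  # sign + exponent + mantissa
    assert nbits == 6          # hence the "fp6" benchmark string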

docs/source/api_ref_dtypes.rst

Lines changed: 0 additions & 2 deletions
@@ -32,7 +32,6 @@ Quantization techniques
 
     to_affine_quantized_intx
     to_affine_quantized_intx_static
-    to_affine_quantized_fpx
     to_affine_quantized_floatx
     to_affine_quantized_floatx_static
     to_marlinqqq_quantized_intx
@@ -51,7 +50,6 @@ Prototype
     Int8DynamicActInt4WeightCPULayout
     MarlinQQQTensor
     MarlinQQQLayout
-    FloatxTensorCoreLayout
     UintxLayout
 
 ..

test/core/test_config.py

Lines changed: 0 additions & 2 deletions
@@ -33,7 +33,6 @@
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
     Float8WeightOnlyConfig,
-    FPXWeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
     Int4DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
@@ -87,7 +86,6 @@
         group_size=128,  # Optional, has default of 64
         bit_width=8,  # Optional, has default of 4
     ),
-    FPXWeightOnlyConfig(ebits=4, mbits=8),
     # Sparsity configs
     SemiSparseWeightConfig(),
     BlockSparseWeightConfig(blocksize=128),

test/dtypes/test_floatx.py

Lines changed: 0 additions & 135 deletions
This file was deleted.

test/quantization/test_quant_api.py

Lines changed: 0 additions & 4 deletions
@@ -48,7 +48,6 @@
     Float8DynamicActivationFloat8WeightConfig,
     Float8StaticActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
-    FPXWeightOnlyConfig,
     FqnToConfig,
     GemliteUIntXWeightOnlyConfig,
     Int4DynamicActivationInt4WeightConfig,
@@ -562,7 +561,6 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
         Int8DynamicActivationInt8WeightConfig(),
         Int8DynamicActivationInt4WeightConfig(),
         Int8WeightOnlyConfig(),
-        FPXWeightOnlyConfig(ebits=4, mbits=3),
         GemliteUIntXWeightOnlyConfig(),
         UIntXWeightOnlyConfig(dtype=torch.uint4),
     ],
@@ -809,7 +807,6 @@ def test_config_deprecation(self):
         """
         from torchao.quantization import (
             Float8StaticActivationFloat8WeightConfig,
-            FPXWeightOnlyConfig,
             GemliteUIntXWeightOnlyConfig,
             Int4DynamicActivationInt4WeightConfig,
             Int8DynamicActivationInt4WeightConfig,
@@ -822,7 +819,6 @@ def test_config_deprecation(self):
         # Map from deprecated API to the args needed to instantiate it
         deprecated_apis_to_args = {
             Float8StaticActivationFloat8WeightConfig: (torch.randn(3),),
-            FPXWeightOnlyConfig: (3, 2),
             GemliteUIntXWeightOnlyConfig: (),
             Int4DynamicActivationInt4WeightConfig: (),
             Int8DynamicActivationInt4WeightConfig: (),

test/test_ops.py

Lines changed: 0 additions & 67 deletions
@@ -18,7 +18,6 @@
 from torch.testing._internal.optests import opcheck
 
 import torchao
-from torchao.dtypes.floatx import from_scaled_tc_floatx
 from torchao.quantization.marlin_qqq import (
     marlin_qqq_workspace,
     pack_to_marlin_qqq,
@@ -56,72 +55,6 @@
 
 
 class TestOps(TestCase):
-    def _create_floatx_inputs(
-        self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype
-    ):
-        # Randomly initialize each byte
-        nbits = 1 + ebits + mbits
-        floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
-        scale = torch.rand(OC).to(dtype) + 0.5
-        fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
-        return floatx_weight.to(device), scale.to(device), fp16_act.to(device)
-
-    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
-    @parametrize("ebits,mbits", [(3, 2), (2, 2)])
-    @parametrize("dtype", [torch.half, torch.bfloat16])
-    def test_quant_llm_linear(self, ebits, mbits, dtype):
-        BS = 2
-        OC = 256
-        IC = 256
-        splitK = 1
-        floatx_weight, scale, fp16_act = self._create_floatx_inputs(
-            ebits, mbits, BS, OC, IC, "cuda", dtype
-        )
-
-        # smoke test
-        torchao.ops.quant_llm_linear(
-            ebits, mbits, fp16_act, floatx_weight, scale, splitK
-        )
-
-        # comprehensive testing
-        test_utils = [
-            "test_schema",
-            "test_autograd_registration",
-            "test_faketensor",
-            "test_aot_dispatch_dynamic",
-        ]
-        opcheck(
-            torch.ops.torchao.quant_llm_linear,
-            (ebits, mbits, fp16_act, floatx_weight, scale, splitK),
-            test_utils=test_utils,
-        )
-
-    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
-    @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
-    @parametrize("ebits,mbits", [(3, 2), (2, 2)])
-    @parametrize("dtype", [torch.half, torch.bfloat16])
-    def test_quant_llm_linear_correctness(
-        self, ebits, mbits, BS, OC, IC, splitK, dtype
-    ):
-        # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
-        floatx_weight, scale, fp16_act = self._create_floatx_inputs(
-            ebits, mbits, BS, OC, IC, "cuda", dtype
-        )
-
-        results_floatx = torchao.ops.quant_llm_linear(
-            ebits, mbits, fp16_act, floatx_weight, scale, splitK
-        )
-
-        fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(
-            dtype
-        )
-        results_fp16 = fp16_act @ fp16_weight.T
-
-        error = (results_floatx - results_fp16).abs().mean()
-        gt = results_fp16.abs().mean()
-        relative_error = error / gt
-        rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
-        assert relative_error < rtol
 
     def _scaled_dot_product_int8_op_ref(
         self,
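
A side note on the packed-weight shape used by the deleted helper: each output channel stores IC values at nbits bits apiece, i.e. IC * nbits / 8 bytes, which is the (OC, IC // 8 * nbits) uint8 tensor created above. A quick sanity check with the smoke-test shapes (illustrative only):

    OC, IC, ebits, mbits = 256, 256, 3, 2
    nbits = 1 + ebits + mbits               # 6 bits per packed fp6 value
    packed_bytes_per_row = IC // 8 * nbits  # packed bytes per output channel, as in _create_floatx_inputs
    assert packed_bytes_per_row == IC * nbits // 8 == 192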
