Commit 85648b2

Refactor Affine Quantized Tensor (#1234)

1 parent 56bf2e8
28 files changed: +2438 −2230 lines

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ def test_to_device(self, apply_quant):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_register_new_dispatch(self):
-        from torchao.dtypes.affine_quantized_tensor import (
+        from torchao.dtypes.affine_quantized_tensor_ops import (
             register_aqt_quantized_linear_dispatch,
             deregister_aqt_quantized_linear_dispatch,
         )
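
The dispatch-registration helpers referenced by this test now live in torchao.dtypes.affine_quantized_tensor_ops instead of torchao.dtypes.affine_quantized_tensor. A minimal sketch of the new import path, assuming the register/deregister pair takes a dispatch-condition callable plus a matching implementation (the callables my_dispatch_condition and my_quantized_linear_impl below are hypothetical):

import torch
from torchao.dtypes.affine_quantized_tensor_ops import (
    register_aqt_quantized_linear_dispatch,
    deregister_aqt_quantized_linear_dispatch,
)

# Hypothetical predicate: decides whether this dispatch handles a given
# (activation, quantized weight, bias) triple.
def my_dispatch_condition(input_tensor, weight_tensor, bias):
    return False  # placeholder; a real condition would inspect the weight's layout

# Hypothetical implementation used when the condition above returns True.
def my_quantized_linear_impl(input_tensor, weight_tensor, bias):
    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)

# Register the pair, then clean up, mirroring the register/deregister cycle
# exercised by test_register_new_dispatch.
register_aqt_quantized_linear_dispatch(my_dispatch_condition, my_quantized_linear_impl)
deregister_aqt_quantized_linear_dispatch(my_dispatch_condition)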

test/dtypes/test_floatx.py

Lines changed: 1 addition & 2 deletions

@@ -10,12 +10,11 @@
     run_tests,
 )
 from torchao.dtypes.floatx import (
-    FloatxTensorCoreAQTTensorImpl,
     FloatxTensorCoreLayout,
     to_scaled_tc_floatx,
     from_scaled_tc_floatx,
 )
-from torchao.dtypes.floatx.floatx import _pack_tc_floatx, _pack_tc_fp6
+from torchao.dtypes.floatx.floatx_tensor_core_layout import _pack_tc_floatx, _pack_tc_fp6, FloatxTensorCoreAQTTensorImpl
 from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32
 from torchao.quantization import (
     quantize_,

test/dtypes/test_uint4.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 import torch
-from torchao.dtypes.uint4 import (
+from torchao.dtypes.uintx.uint4_layout import (
     UInt4Tensor,
     PerChannelSymmetricWeightUInt4Tensor,
 )

test/dtypes/test_uintx.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@

 import torch

-from torchao.dtypes.uintx import to_uintx
+from torchao.dtypes.uintx.uintx_layout import to_uintx
 from torchao.quantization.quant_api import quantize_, uintx_weight_only
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,

test/hqq/test_hqq_affine.py

Lines changed: 1 addition & 6 deletions

@@ -1,12 +1,7 @@
 import unittest
 import torch
-from torchao.dtypes.affine_quantized_tensor import (
-    to_affine_quantized_intx,
+from torchao.quantization import (
     ZeroPointDomain,
-    PlainAQTTensorImpl,
-    PlainLayout,
-    TensorCoreTiledAQTTensorImpl,
-    TensorCoreTiledLayout,
     MappingType,
 )

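After the refactor the HQQ test pulls its quantization enums from torchao.quantization rather than from the dtypes module, and the layout/impl classes are no longer imported directly. A small illustrative sketch of the new import path (the specific enum members chosen below are just examples, not what the test necessarily uses):

from torchao.quantization import MappingType, ZeroPointDomain

# Enum values describing an affine quantization scheme: asymmetric mapping
# with the zero point kept in the floating-point domain.
mapping = MappingType.ASYMMETRIC
zp_domain = ZeroPointDomain.FLOAT
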
test/prototype/test_sparse_api.py

Lines changed: 1 addition & 1 deletion

@@ -194,7 +194,7 @@ def test_sparse(self, compile):
         quantize_(model_copy, int8_dynamic_activation_int8_weight())
         reference = model_copy(input)

-        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout
+        from torchao.dtypes import BlockSparseLayout

         quantize_(
             model,

torchao/dtypes/__init__.py

Lines changed: 21 additions & 13 deletions

@@ -1,13 +1,7 @@
+from . import affine_quantized_tensor_ops
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
-    Float8AQTTensorImpl,
-    Float8Layout,
-    Layout,
-    MarlinQQQLayout,
-    MarlinSparseLayout,
-    PlainLayout,
-    SemiSparseLayout,
-    TensorCoreTiledLayout,
+    MarlinQQQTensor,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future
@@ -16,15 +10,26 @@
     to_affine_quantized_intx_static,
     to_marlinqqq_quantized_intx,
 )
+from .floatx import (
+    Float8Layout,
+)
 from .nf4tensor import NF4Tensor, to_nf4
-
-# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
-from .uint4 import UInt4Tensor
+from .uintx import (
+    BlockSparseLayout,
+    MarlinQQQLayout,
+    MarlinSparseLayout,
+    SemiSparseLayout,
+    TensorCoreTiledLayout,
+    UintxLayout,
+)
+from .utils import (
+    Layout,
+    PlainLayout,
+)

 __all__ = [
     "NF4Tensor",
     "to_nf4",
-    "UInt4Tensor",
     "AffineQuantizedTensor",
     "to_affine_quantized_intx",
     "to_affine_quantized_intx_static",
@@ -37,7 +42,10 @@
     "SemiSparseLayout",
     "TensorCoreTiledLayout",
     "Float8Layout",
-    "Float8AQTTensorImpl",
     "MarlinSparseLayout",
+    "affine_quantized_tensor_ops",
+    "BlockSparseLayout",
+    "UintxLayout",
+    "MarlinQQQTensor",
     "MarlinQQQLayout",
 ]
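
With the reorganized __init__.py, the layout classes stay importable from the torchao.dtypes package root even though their definitions moved into the floatx, uintx, and utils submodules. A minimal sketch based on the names listed in __all__ above (the PlainLayout() construction and isinstance check are illustrative assumptions, not taken from this diff):

from torchao.dtypes import (
    AffineQuantizedTensor,
    BlockSparseLayout,      # newly re-exported, defined under .uintx
    Float8Layout,           # now defined under .floatx
    TensorCoreTiledLayout,  # now defined under .uintx
    Layout,                 # base layout class, now under .utils
    PlainLayout,            # default unpacked layout, now under .utils
)

# Layout objects are lightweight configuration values; the plain layout takes
# no arguments and should remain a subclass of the base Layout class.
default_layout = PlainLayout()
assert isinstance(default_layout, Layout)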
