Aqt rename#1 Layout -> TensorImpl #1046

Merged 11 commits on Oct 10, 2024
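This is the first PR in a series renaming the AffineQuantizedTensor (AQT) internals: the `*AQTLayout` tensor subclasses become `*AQTTensorImpl`, and the `layout_tensor` attribute that holds the packed weight data becomes `tensor_impl`. A minimal sketch of what the rename means for calling code (hedged: assumes torchao at this commit, and uses the float8 weight-only quantization API exercised by the serialization test below):

```python
import torch
from torchao.quantization import quantize_, float8_weight_only

# A hypothetical float8 weight-only quantized linear layer, mirroring the
# serialization test below (illustration only, not code from this PR).
layer = torch.nn.Linear(64, 64, dtype=torch.bfloat16, device="cuda")
quantize_(layer, float8_weight_only())

# Before this PR, the packed representation hung off `layout_tensor`:
#   float8_data = layer.weight.layout_tensor.float8_data
# After the rename, the same data is reached through `tensor_impl`:
float8_data = layer.weight.tensor_impl.float8_data
print(float8_data.dtype)  # expected: torch.float8_e4m3fn
```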
benchmarks/benchmark_fp6.py (1 addition, 1 deletion)

```diff
@@ -2,7 +2,7 @@
 import pandas as pd
 import torch.nn.functional as F
 from torchao.dtypes import to_affine_quantized_fpx
-from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
+from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm
```
test/dtypes/test_affine_quantized_float.py (7 additions, 7 deletions)

```diff
@@ -210,18 +210,18 @@ def test_serialization(self, mode: str):
 
         # Compare weights
         if mode == "weight-only":
-            original_weight = original_layer.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_weight = new_layer.weight.layout_tensor.float8_data.to(
+            original_weight = original_layer.weight.tensor_impl.float8_data.to(
                 torch.float32
             )
+            new_weight = new_layer.weight.tensor_impl.float8_data.to(torch.float32)
         else:
-            original_weight = original_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
+            original_weight = original_layer.weight.original_weight_tensor.tensor_impl.float8_data.to(
                 torch.float32
             )
-            new_weight = new_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
+            new_weight = (
+                new_layer.weight.original_weight_tensor.tensor_impl.float8_data.to(
+                    torch.float32
+                )
             )
 
         assert torch.allclose(
```
test/dtypes/test_floatx.py (7 additions, 7 deletions)

```diff
@@ -9,7 +9,7 @@
     run_tests,
 )
 from torchao.dtypes.floatx import (
-    FloatxTensorCoreAQTLayout,
+    FloatxTensorCoreAQTTensorImpl,
     FloatxTensorCoreLayoutType,
     to_scaled_tc_floatx,
     from_scaled_tc_floatx,
@@ -28,7 +28,7 @@
 _Floatx_DTYPES = [(3, 2), (2, 2)]
 
 
-class TestFloatxTensorCoreAQTLayout(TestCase):
+class TestFloatxTensorCoreAQTTensorImpl(TestCase):
     @parametrize("device", _DEVICES)
     def test_pack_tc_fp6_correctness(self, device):
         x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)
@@ -82,10 +82,10 @@ def test_to_copy_device(self, ebits, mbits):
         scale = choose_qparams_affine_floatx(x, ebits, mbits)
         x = quantize_affine_floatx(x, scale, ebits, mbits)
         layout_type = FloatxTensorCoreLayoutType(ebits, mbits)
-        floatx_layout_tensor = FloatxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
-        assert floatx_layout_tensor.device.type == "cuda"
-        floatx_layout_tensor = floatx_layout_tensor.cpu()
-        assert floatx_layout_tensor.device.type == "cpu"
+        floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, layout_type).cuda()
+        assert floatx_tensor_impl.device.type == "cuda"
+        floatx_tensor_impl = floatx_tensor_impl.cpu()
+        assert floatx_tensor_impl.device.type == "cpu"
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
@@ -106,7 +106,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias):
         torch.testing.assert_close(actual, expected)
 
 
-instantiate_parametrized_tests(TestFloatxTensorCoreAQTLayout)
+instantiate_parametrized_tests(TestFloatxTensorCoreAQTTensorImpl)
 
 
 if __name__ == "__main__":
```
test/hqq/test_hqq_affine.py (2 additions, 2 deletions)

```diff
@@ -3,9 +3,9 @@
 from torchao.dtypes.affine_quantized_tensor import (
     to_affine_quantized_intx,
     ZeroPointDomain,
-    PlainAQTLayout,
+    PlainAQTTensorImpl,
     PlainLayoutType,
-    TensorCoreTiledAQTLayout,
+    TensorCoreTiledAQTTensorImpl,
     TensorCoreTiledLayoutType,
     MappingType,
 )
```
test/integration/test_integration.py (1 addition, 1 deletion)

```diff
@@ -1051,7 +1051,7 @@ def forward(self, x):
         self.assertTrue(torch.equal(ref_q, test))
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @unittest.skipIf(is_fbcode(), "'PlainAQTLayout' object has no attribute 'int_data'")
+    @unittest.skipIf(is_fbcode(), "'PlainAQTTensorImpl' object has no attribute 'int_data'")
     @torch.no_grad()
     def test_save_load_dqtensors(self, device, dtype):
         if device == "cpu":
```
torchao/dtypes/__init__.py (2 additions, 2 deletions)

```diff
@@ -14,7 +14,7 @@
     SemiSparseLayoutType,
     TensorCoreTiledLayoutType,
     Float8LayoutType,
-    Float8AQTLayout,
+    Float8AQTTensorImpl,
     MarlinSparseLayoutType,
 )
 
@@ -33,6 +33,6 @@
     "SemiSparseLayoutType",
     "TensorCoreTiledLayoutType",
     "Float8LayoutType",
-    "Float8AQTLayout",
+    "Float8AQTTensorImpl",
     "MarlinSparseLayoutType",
 ]
```
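A quick smoke check of the renamed export surface (a hedged sketch, assuming torchao is installed at this commit):

```python
import torchao.dtypes as dtypes

# The renamed symbol is exported via __all__ after this PR...
assert hasattr(dtypes, "Float8AQTTensorImpl")
# ...and the diff removes the old name outright rather than aliasing it.
assert not hasattr(dtypes, "Float8AQTLayout")
```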