
Commit 8376847

test fixes
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 03d01ad commit 8376847

File tree

1 file changed (+8, -8 lines)


test/integration/test_integration.py

Lines changed: 8 additions & 8 deletions
@@ -68,9 +68,9 @@
 )
 from torchao.quantization.autoquant import (
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQWeightOnlyQuantizedLinearWeight,
-    AQWeightOnlyQuantizedLinearWeight2,
-    AQWeightOnlyQuantizedLinearWeight3,
+    AQInt8WeightOnlyQuantizedLinearWeight,
+    AQInt8WeightOnlyQuantizedLinearWeight2,
+    AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
 
 )
@@ -727,21 +727,21 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
         )
     def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
-            AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
+            AQInt8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
     def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
-            AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
+            AQInt8WeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
     def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
-            AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
+            AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -1498,10 +1498,10 @@ def test_get_model_size_autoquant(self, device, dtype):
         size = torchao.utils.get_model_size_in_bytes(model)
 
         from torchao.quantization.autoquant import (
-            AQWeightOnlyQuantizedLinearWeight2,
+            AQInt8WeightOnlyQuantizedLinearWeight2,
         )
         qtensor_class_list = (
-            AQWeightOnlyQuantizedLinearWeight2,
+            AQInt8WeightOnlyQuantizedLinearWeight2,
        )
         mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False)
         mod(example_input)
