|
68 | 68 | )
|
69 | 69 | from torchao.quantization.autoquant import (
|
70 | 70 | AQInt8DynamicallyQuantizedLinearWeight,
|
71 |
| - AQWeightOnlyQuantizedLinearWeight, |
72 |
| - AQWeightOnlyQuantizedLinearWeight2, |
73 |
| - AQWeightOnlyQuantizedLinearWeight3, |
| 71 | + AQInt8WeightOnlyQuantizedLinearWeight, |
| 72 | + AQInt8WeightOnlyQuantizedLinearWeight2, |
| 73 | + AQInt8WeightOnlyQuantizedLinearWeight3, |
74 | 74 | AutoQuantizableLinearWeight,
|
75 | 75 |
|
76 | 76 | )
|
@@ -727,21 +727,21 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
|
727 | 727 | )
|
728 | 728 | def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
|
729 | 729 | self._test_lin_weight_subclass_impl(
|
730 |
| - AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype |
| 730 | + AQInt8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype |
731 | 731 | )
|
732 | 732 |
|
733 | 733 | @parameterized.expand(COMMON_DEVICE_DTYPE)
|
734 | 734 | @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
|
735 | 735 | def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
|
736 | 736 | self._test_lin_weight_subclass_impl(
|
737 |
| - AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype |
| 737 | + AQInt8WeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype |
738 | 738 | )
|
739 | 739 |
|
740 | 740 | @parameterized.expand(COMMON_DEVICE_DTYPE)
|
741 | 741 | @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
|
742 | 742 | def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
|
743 | 743 | self._test_lin_weight_subclass_impl(
|
744 |
| - AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype |
| 744 | + AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype |
745 | 745 | )
|
746 | 746 |
|
747 | 747 | @parameterized.expand(COMMON_DEVICE_DTYPE)
|
@@ -1498,10 +1498,10 @@ def test_get_model_size_autoquant(self, device, dtype):
|
1498 | 1498 | size = torchao.utils.get_model_size_in_bytes(model)
|
1499 | 1499 |
|
1500 | 1500 | from torchao.quantization.autoquant import (
|
1501 |
| - AQWeightOnlyQuantizedLinearWeight2, |
| 1501 | + AQInt8WeightOnlyQuantizedLinearWeight2, |
1502 | 1502 | )
|
1503 | 1503 | qtensor_class_list = (
|
1504 |
| - AQWeightOnlyQuantizedLinearWeight2, |
| 1504 | + AQInt8WeightOnlyQuantizedLinearWeight2, |
1505 | 1505 | )
|
1506 | 1506 | mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False)
|
1507 | 1507 | mod(example_input)
|
|
0 commit comments