1616)
1717
1818from torchao .quantization import (
19- float8_dynamic_activation_float8_weight ,
20- float8_weight_only ,
21- int4_weight_only ,
22- int8_dynamic_activation_int8_weight ,
23- int8_weight_only ,
19+ Float8DynamicActivationFloat8WeightConfig ,
20+ Float8WeightOnlyConfig ,
21+ Int4WeightOnlyConfig ,
22+ Int8DynamicActivationInt8WeightConfig ,
23+ Int8WeightOnlyConfig ,
2424)
2525from torchao .quantization .observer import PerRow , PerTensor
2626from torchao .quantization .quant_api import quantize_
4242class TestAffineQuantizedTensorParallel (DTensorTestBase ):
4343 """Basic test case for tensor subclasses"""
4444
45- QUANT_METHOD_FN = staticmethod (int8_weight_only )
45+ QUANT_METHOD_FN = staticmethod (Int8WeightOnlyConfig )
4646 QUANT_METHOD_KWARGS = {}
4747
4848 @staticmethod
@@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
133133
134134
135135class TestInt8woAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
136- QUANT_METHOD_FN = staticmethod (int8_weight_only )
136+ QUANT_METHOD_FN = staticmethod (Int8WeightOnlyConfig )
137137 COMMON_DTYPES = [torch .bfloat16 , torch .float16 , torch .float32 ]
138138
139139 @common_utils .parametrize ("dtype" , COMMON_DTYPES )
@@ -144,7 +144,7 @@ def test_tp(self, dtype):
144144
145145
146146class TestInt4woAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
147- QUANT_METHOD_FN = staticmethod (int4_weight_only )
147+ QUANT_METHOD_FN = staticmethod (Int4WeightOnlyConfig )
148148 QUANT_METHOD_KWARGS = {"version" : 1 }
149149 COMMON_DTYPES = [torch .bfloat16 ]
150150
@@ -167,20 +167,20 @@ class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel):
167167 @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
168168 @unittest .skipIf (not has_gemlite , "gemlite not available" )
169169 def test_tp_gemlite (self , dtype ):
170- from torchao .quantization import gemlite_uintx_weight_only
170+ from torchao .quantization import GemliteUIntXWeightOnlyConfig
171171
172172 for packing_bitwidth in [32 , 8 ]:
173173 for bit_width in [4 , 8 ]:
174174 for group_size in [64 , 32 , None ] if bit_width == 4 else [None ]:
175- api = lambda : gemlite_uintx_weight_only (
175+ api = lambda : GemliteUIntXWeightOnlyConfig (
176176 group_size , bit_width , packing_bitwidth
177177 )
178178 self .QUANT_METHOD_FN = staticmethod (api )
179179 return self ._test_tp (dtype )
180180
181181
182182class TestInt8dqAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
183- QUANT_METHOD_FN = staticmethod (int8_dynamic_activation_int8_weight )
183+ QUANT_METHOD_FN = staticmethod (Int8DynamicActivationInt8WeightConfig )
184184 COMMON_DTYPES = [torch .bfloat16 ]
185185
186186 @common_utils .parametrize ("dtype" , COMMON_DTYPES )
@@ -199,7 +199,7 @@ def test_tp(self, dtype):
199199if torch .cuda .is_available () and torch .cuda .get_device_capability () >= (9 , 0 ):
200200
201201 class TestFloat8woAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
202- QUANT_METHOD_FN = staticmethod (float8_weight_only )
202+ QUANT_METHOD_FN = staticmethod (Float8WeightOnlyConfig )
203203 COMMON_DTYPES = [torch .bfloat16 , torch .float16 , torch .float32 ]
204204
205205 @common_utils .parametrize ("dtype" , COMMON_DTYPES )
@@ -211,7 +211,7 @@ def test_tp(self, dtype):
211211 class TestFloat8dqTensorAffineQuantizedTensorParallel (
212212 TestAffineQuantizedTensorParallel
213213 ):
214- QUANT_METHOD_FN = staticmethod (float8_dynamic_activation_float8_weight )
214+ QUANT_METHOD_FN = staticmethod (Float8DynamicActivationFloat8WeightConfig )
215215 QUANT_METHOD_KWARGS = {"granularity" : PerTensor ()}
216216 COMMON_DTYPES = [torch .bfloat16 , torch .float16 , torch .float32 ]
217217
@@ -224,7 +224,7 @@ def test_tp(self, dtype):
224224 class TestFloat8dqRowAffineQuantizedTensorParallel (
225225 TestAffineQuantizedTensorParallel
226226 ):
227- QUANT_METHOD_FN = staticmethod (float8_dynamic_activation_float8_weight )
227+ QUANT_METHOD_FN = staticmethod (Float8DynamicActivationFloat8WeightConfig )
228228 QUANT_METHOD_KWARGS = {"granularity" : PerRow ()}
229229 COMMON_DTYPES = [torch .bfloat16 ]
230230
0 commit comments