     _replace_with_custom_fn_if_matches_filter,
     Quantizer,
     TwoStepQuantizer,
-    int8da_int4w,
-    int4wo,
-    int8wo,
-    int8da_int8w,
+    int8_dynamic_activation_int4_weight,
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
 )
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
 
 class TorchCompileDynamicQuantizer(Quantizer):
     def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-        quantize(model, int8da_int8w())
+        quantize(model, int8_dynamic_activation_int8_weight())
         return model
 
 class ToyLinearModel(torch.nn.Module):
@@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
-        m = quantize(m, int8da_int8w())
+        m = quantize(m, int8_dynamic_activation_int8_weight())
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
         )
         m = ToyLinearModel().eval().cpu()
         def api(model):
-            model = quantize(model, int8wo())
+            model = quantize(model, int8_weight_only())
             unwrap_tensor_subclass(model)
 
         api(m)
@@ -335,7 +335,7 @@ def test_8da4w_quantizer_eval(self):
         )
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq_quantizer_int4wo(self):
+    def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
         torchao._models.llama.model.use_index_put_for_kv_cache = True
@@ -397,7 +397,7 @@ def test_gptq_quantizer_int4wo(self):
         )
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_quantizer_int4wo(self):
+    def test_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao._models._eval import TransformerEvalWrapper
         precision = torch.bfloat16
@@ -499,11 +499,11 @@ def test_eval_wrapper_llama3(self):
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        groupsize = 32
+        group_size = 32
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        m = quantize(m, int8da_int4w(groupsize=groupsize))
+        m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -514,7 +514,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
-        quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
         m_copy = quantizer.quantize(m_copy)
         assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
         assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
@@ -531,13 +531,13 @@ def test_quantized_tensor_subclass_int4(self):
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
 
-        groupsize = 32
-        m = quantize(m, int4wo(groupsize=groupsize))
+        group_size = 32
+        m = quantize(m, int4_weight_only(group_size=group_size))
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
 
         # reference
-        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)
 
         res = m(*example_inputs)
         ref = m_copy(*example_inputs)
@@ -552,7 +552,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        m = quantize(m, int8wo())
+        m = quantize(m, int8_weight_only())
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -575,7 +575,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-        m = quantize(m, int8da_int8w())
+        m = quantize(m, int8_dynamic_activation_int8_weight())
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -602,29 +602,14 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         # make sure it compiles
         torch._export.aot_compile(m_unwrapped, example_inputs)
 
-    def test_register_apply_tensor_subclass(self):
-        from torchao import register_apply_tensor_subclass
-        def apply_my_dtype(weight):
-            return weight * 2
-
-        m = ToyLinearModel().eval()
-        example_inputs = m.example_inputs()
-        with self.assertRaisesRegex(ValueError, "not supported"):
-            quantize(m, "my_dtype")
-
-        register_apply_tensor_subclass("my_dtype", apply_my_dtype)
-        # make sure it runs
-        quantize(m, "my_dtype")
-        m(*example_inputs)
-
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_save_load(self):
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16)
 
-        m = quantize(m, "int8_weight_only")
+        m = quantize(m, int8_weight_only())
         ref = m(*example_inputs)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(m.state_dict(), f)
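
For context, a minimal sketch of the renamed API in use after this change. It assumes the quantize entry point and the renamed factory functions are exported from torchao.quantization.quant_api as in the updated imports above; the Sequential model and group size are illustrative stand-ins, not from this diff.

import torch
from torchao.quantization.quant_api import (
    quantize,
    int4_weight_only,
    int8_weight_only,
    int8_dynamic_activation_int8_weight,
)

# Any eval-mode model with nn.Linear layers works; two layers for illustration.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.Linear(64, 64)
).eval()

# Factory objects are now constructed and passed directly, replacing both the
# terse names (int8wo, int8da_int8w, ...) and the string form ("int8_weight_only").
model = quantize(model, int8_weight_only())

# The int4 weight-only path also renames its keyword: group_size, not groupsize.
# model = quantize(model, int4_weight_only(group_size=32))
# model = quantize(model, int8_dynamic_activation_int8_weight())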