11
11
import torch
12
12
import torch .nn as nn
13
13
from torch ._inductor .utils import run_and_get_code
14
-
14
+ from torch . _dynamo import config
15
15
from torch .ao .quantization import MinMaxObserver , QConfigMapping
16
16
17
17
from torchao .quantization .dynamic_quant import (
21
21
apply_dynamic_quant ,
22
22
apply_weight_only_int8_quant ,
23
23
change_linear_weights_to_dqtensors ,
24
- change_linear_weights_to_woqtensors ,
24
+ change_linear_weights_to_int8woqtensors ,
25
+ change_linear_weights_to_int4woqtensors ,
25
26
_replace_with_custom_fn_if_matches_filter ,
26
27
)
27
28
from torchao .quantization .quant_primitives import (
42
43
swap_linear_with_smooth_fq_linear ,
43
44
)
44
45
from torchao .quantization .subclass import (
45
- DynamicallyQuantizedLinearWeight ,
46
- WeightOnlyQuantizedLinearWeight
46
+ Int8DynamicallyQuantizedLinearWeight ,
47
+ Int8WeightOnlyQuantizedLinearWeight ,
48
+ Int4WeightOnlyQuantizedLinearWeight
47
49
)
48
50
from torchao .quantization .utils import (
49
51
apply_logging_hook ,
59
61
import os
60
62
61
63
# Seed torch's global RNG so randomly initialised layers and inputs are repeatable across runs.
torch .manual_seed (0 )
64
# NOTE(review): sets the torch._dynamo cache size limit to 100; presumably needed because
# the tests below compile many dtype/shape variants -- confirm.
+ config .cache_size_limit = 100
62
65
63
66
64
67
class SmoothquantUnitTest (unittest .TestCase ):
@@ -788,62 +791,108 @@ def test_qlinear_per_channel_numerics_cuda(self):
788
791
789
792
790
793
class TestSubclass(unittest.TestCase):
    """Accuracy tests for the quantized linear-weight tensor subclasses.

    Every check builds CUDA ``Linear`` layers, replaces their weights with a
    quantized subclass (directly via a ``from_float`` callable, or through one
    of the ``change_linear_weights_to_*`` APIs) and asserts that the value
    returned by the ``SQNR`` helper against the floating-point reference is
    greater than a per-test threshold.
    """

    def _test_dequantize_impl(
        self,
        test_subclass_from_float,
        min_sqnr=35,
        test_dtype=torch.bfloat16,
        test_shape=[32, 64, 64],
    ):
        """Check that ``dequantize()`` of a quantized weight matches the original.

        Args:
            test_subclass_from_float: callable mapping a float weight to the
                quantized subclass instance under test.
            min_sqnr: threshold that both SQNR values must exceed.
            test_dtype: dtype the ``Linear`` layer is cast to before quantizing.
            test_shape: ``[m, k, n]``; only ``k`` and ``n`` are used here, as
                the in/out features of the layer.
        """
        _, k, n = test_shape
        lin = torch.nn.Linear(k, n, device="cuda").to(test_dtype)
        w = lin.weight.detach()
        lin.weight = torch.nn.Parameter(
            test_subclass_from_float(lin.weight), requires_grad=False
        )
        self.assertGreater(
            SQNR(w, lin.weight.dequantize()),
            min_sqnr,
            f"{lin.weight.__class__.__name__} failed dtype={test_dtype}",
        )
        # The transposed view of the subclass must dequantize correctly too.
        self.assertGreater(
            SQNR(w.t(), lin.weight.t().dequantize()),
            min_sqnr,
            f"{lin.weight.__class__.__name__} failed transpose on dtype={test_dtype}",
        )

    def test_dequantize_int8_dynamic_quant_subclass(self):
        """Dequantize check for the int8 dynamic-quant weight over three dtypes."""
        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            self._test_dequantize_impl(
                Int8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
            )

    def test_dequantize_int8_weight_only_quant_subclass(self):
        """Dequantize check for the int8 weight-only weight over three dtypes."""
        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            self._test_dequantize_impl(
                Int8WeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype
            )

    def test_dequantize_int4_weight_only_quant_subclass(self):
        """Dequantize check for the int4 weight-only weight (default bfloat16).

        Runs once with the default ``from_float`` arguments and then over a
        grid of ``groupsize`` / ``inner_k_tiles`` / ``m`` values.
        """
        self._test_dequantize_impl(
            Int4WeightOnlyQuantizedLinearWeight.from_float,
            15,
            test_shape=[1, 1024, 8],
        )
        for groupsize in [256, 128]:
            for inner_k_tiles in [8, 2]:
                for m in [1, 256]:
                    # Bind the loop variables as defaults so the lambda does
                    # not depend on late-binding closure semantics.
                    self._test_dequantize_impl(
                        lambda w, g=groupsize, t=inner_k_tiles: (
                            Int4WeightOnlyQuantizedLinearWeight.from_float(w, g, t)
                        ),
                        15,
                        test_shape=[m, 256, 8],
                    )

    def _test_lin_weight_subclass_impl(
        self,
        test_subclass_from_float,
        min_sqnr=35,
        test_dtype=torch.bfloat16,
        test_shape=[32, 64, 32],
    ):
        """Compare a ``Linear`` with a quantized weight against its float output.

        The comparison is made twice: once in eager mode and once after
        ``torch.compile(..., mode='max-autotune')``.

        Args:
            test_subclass_from_float: callable mapping a float weight to the
                quantized subclass instance under test.
            min_sqnr: threshold that both SQNR values must exceed.
            test_dtype: dtype of the input and of the layer.
            test_shape: ``[m, k, n]`` -- input is ``(m, k)``, layer is ``k -> n``.
        """
        m, k, n = test_shape
        x = torch.randn(m, k, device="cuda", dtype=test_dtype)
        lin = torch.nn.Linear(k, n, device="cuda").to(test_dtype)
        ref_f = lin(x)

        lin.weight = torch.nn.Parameter(
            test_subclass_from_float(lin.weight), requires_grad=False
        )
        test = lin(x)
        self.assertGreater(
            SQNR(ref_f, test),
            min_sqnr,
            f"{lin.weight.__class__.__name__} failed, no compile, "
            f"dtype={test_dtype}, (m, k, n)={test_shape}",
        )
        lin_comp = torch.compile(lin, mode='max-autotune')
        test_comp = lin_comp(x)
        self.assertGreater(
            SQNR(ref_f, test_comp),
            min_sqnr,
            f"{lin.weight.__class__.__name__} failed at compile with "
            f"dtype={test_dtype}, (m, k, n)={test_shape}",
        )

    def test_int8_dynamic_quant_subclass(self):
        """Linear accuracy check for the int8 dynamic-quant weight over three dtypes."""
        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            self._test_lin_weight_subclass_impl(
                Int8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
            )

    def test_int8_weight_only_quant_subclass(self):
        """Linear accuracy check for the int8 weight-only weight over three dtypes."""
        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            self._test_lin_weight_subclass_impl(
                Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype
            )

    def test_int4_weight_only_quant_subclass(self):
        """Linear accuracy check for the int4 weight-only weight (default bfloat16).

        Runs once with the default ``from_float`` arguments and then over a
        grid of ``groupsize`` / ``inner_k_tiles`` / ``m`` values.
        """
        self._test_lin_weight_subclass_impl(
            Int4WeightOnlyQuantizedLinearWeight.from_float,
            10,
            test_shape=[1, 1024, 8],
        )
        for groupsize in [128, 64]:
            for inner_k_tiles in [4, 2]:
                for m in [1, 256]:
                    self._test_lin_weight_subclass_impl(
                        lambda w, g=groupsize, t=inner_k_tiles: (
                            Int4WeightOnlyQuantizedLinearWeight.from_float(w, g, t)
                        ),
                        10,
                        test_shape=[m, 256, 8],
                    )

    @torch.no_grad()
    def _test_lin_weight_subclass_api_impl(
        self,
        api,
        min_sqnr=35,
        test_dtype=torch.bfloat16,
        test_shape=[32, 64, 32],
    ):
        """Apply a module-level quantization API and compare against float output.

        Builds ``Linear(k, n) -> ReLU -> Linear(n, n)`` on CUDA, records the
        float output, calls ``api(mod)`` (which is expected to modify ``mod``
        in place, since its return value is ignored) and compares the eager
        and ``max-autotune``-compiled outputs with the reference.

        Args:
            api: callable taking the module; must expose ``__name__`` because
                it is used in the failure messages.
            min_sqnr: threshold that both SQNR values must exceed.
            test_dtype: dtype of the input and of the model.
            test_shape: ``[m, k, n]`` -- input is ``(m, k)``.
        """
        m, k, n = test_shape
        x = torch.randn(m, k, device="cuda", dtype=test_dtype)
        mod = nn.Sequential(
            nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda")
        ).to(test_dtype)
        ref_f = mod(x)
        api(mod)
        test = mod(x)
        self.assertGreater(
            SQNR(ref_f, test),
            min_sqnr,
            f"{api.__name__} failed, no compile dtype={test_dtype}, "
            f"(m, k, n)={test_shape}",
        )
        mod_qc = torch.compile(mod, mode="max-autotune")
        test_comp = mod_qc(x)
        self.assertGreater(
            SQNR(ref_f, test_comp),
            min_sqnr,
            f"{api.__name__} failed when compiled with dtype={test_dtype}, "
            f"(m, k, n)={test_shape}",
        )

    def test_int8_dynamic_quant_subclass_api(self):
        """API accuracy check for ``change_linear_weights_to_dqtensors`` over three dtypes."""
        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            # Forward the loop dtype; without it the default bfloat16 case
            # would simply run three times.
            self._test_lin_weight_subclass_api_impl(
                change_linear_weights_to_dqtensors, 35, test_dtype
            )

    def test_int8_weight_only_quant_subclass_api(self):
        """API accuracy check for ``change_linear_weights_to_int8woqtensors`` over three dtypes."""
        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            # Forward the loop dtype; without it the default bfloat16 case
            # would simply run three times.
            self._test_lin_weight_subclass_api_impl(
                change_linear_weights_to_int8woqtensors, 40, test_dtype
            )

    def test_int4_weight_only_quant_subclass_api(self):
        """API accuracy check for ``change_linear_weights_to_int4woqtensors`` (default bfloat16).

        Runs once with the API's default arguments and then over a grid of
        ``groupsize`` / ``inner_k_tiles`` keyword arguments.
        """
        self._test_lin_weight_subclass_api_impl(
            change_linear_weights_to_int4woqtensors,
            15,
            test_shape=[1, 1024, 256],
        )
        for groupsize in [64, 32]:
            for inner_k_tiles in [4, 2]:
                kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
                # A lambda (not functools.partial) is used because the impl
                # reads ``api.__name__`` for its failure messages.
                self._test_lin_weight_subclass_api_impl(
                    lambda mod, kwargs=kwargs: (
                        change_linear_weights_to_int4woqtensors(mod, **kwargs)
                    ),
                    15,
                    test_shape=[256, 256, 8],
                )
847
896
848
897
class TestDynamicQuant (unittest .TestCase ):
849
898
def test_dynamic_quant (self ):
@@ -906,7 +955,7 @@ def test_weight_only_quant_use_mixed_mm(self):
906
955
907
956
class TestSaveLoadMeta (unittest .TestCase ):
908
957
@torch .no_grad ()
909
- def _test_handle_save_load_meta_impl (self , api ):
958
+ def _test_handle_save_load_meta_impl (self , api , min_sqnr = 35 ):
910
959
m , k , n = 32 , 64 , 32
911
960
class test_model (nn .Module ):
912
961
def __init__ (self ):
@@ -934,7 +983,7 @@ def forward(self, x):
934
983
model_qc = torch .compile (model , mode = "max-autotune" )
935
984
ref_q = model_qc (x ).detach ()
936
985
937
- assert SQNR (ref_f , ref_q ) > 35
986
+ assert SQNR (ref_f , ref_q ) > min_sqnr
938
987
939
988
# load model structure
940
989
with torch .device ('meta' ):
@@ -951,16 +1000,20 @@ def forward(self, x):
951
1000
model_qc = torch .compile (model , mode = "max-autotune" )
952
1001
test = model_qc (x ).detach ()
953
1002
954
- assert SQNR (ref_f , test ) > 35
1003
+ assert SQNR (ref_f , test ) > min_sqnr
955
1004
self .assertTrue (torch .equal (ref_q , test ))
956
1005
957
1006
@torch.no_grad()
def test_save_load_dqtensors(self):
    """Run the shared save/load helper with ``change_linear_weights_to_dqtensors``.

    The helper's ``min_sqnr`` is left at its default.
    """
    quantize_api = change_linear_weights_to_dqtensors
    self._test_handle_save_load_meta_impl(quantize_api)
960
1009
961
1010
@torch.no_grad()
def test_save_load_int8woqtensors(self):
    """Run the shared save/load helper with ``change_linear_weights_to_int8woqtensors``.

    The helper's ``min_sqnr`` is left at its default.
    """
    quantize_api = change_linear_weights_to_int8woqtensors
    self._test_handle_save_load_meta_impl(quantize_api)
1013
+
1014
@torch.no_grad()
def test_save_load_int4woqtensors(self):
    """Run the shared save/load helper with ``change_linear_weights_to_int4woqtensors``.

    Passes 20 as the helper's ``min_sqnr`` instead of relying on the default.
    """
    quantize_api = change_linear_weights_to_int4woqtensors
    sqnr_floor = 20
    self._test_handle_save_load_meta_impl(quantize_api, sqnr_floor)
964
1017
965
1018
class TorchCompileUnitTest (unittest .TestCase ):
966
1019
def test_fullgraph (self ):
0 commit comments