@@ -636,20 +636,20 @@ def test_get_full_quantization_config(self):
636636 expected_full_quantization_config = [
637637 {
638638 'regex' : '.*' ,
639- 'operation' : '*' ,
639+ 'operation' : _TFLOpName . ALL_SUPPORTED ,
640640 'algorithm_key' : _AlgorithmName .MIN_MAX_UNIFORM_QUANT ,
641641 'op_config' : {
642642 'activation_tensor_config' : {
643643 'num_bits' : 8 ,
644644 'symmetric' : False ,
645645 'granularity' : _QuantGranularity .TENSORWISE ,
646- 'dtype' : ' INT' ,
646+ 'dtype' : _TensorDataType . INT ,
647647 },
648648 'weight_tensor_config' : {
649649 'num_bits' : 8 ,
650650 'symmetric' : True ,
651651 'granularity' : _QuantGranularity .TENSORWISE ,
652- 'dtype' : ' INT' ,
652+ 'dtype' : _TensorDataType . INT ,
653653 },
654654 # WEIGHT_ONLY.
655655 'compute_precision' : _ComputePrecision .INTEGER ,
@@ -660,11 +660,11 @@ def test_get_full_quantization_config(self):
660660 },
661661 {
662662 'regex' : '.*' ,
663- 'operation' : ' BATCH_MATMUL' ,
663+ 'operation' : _TFLOpName . BATCH_MATMUL ,
664664 'algorithm_key' : _AlgorithmName .MIN_MAX_UNIFORM_QUANT ,
665665 'op_config' : {
666666 'weight_tensor_config' : {
667- 'dtype' : ' INT' ,
667+ 'dtype' : _TensorDataType . INT ,
668668 'num_bits' : 8 ,
669669 'symmetric' : True ,
670670 'granularity' : _QuantGranularity .TENSORWISE ,
@@ -678,11 +678,11 @@ def test_get_full_quantization_config(self):
678678 },
679679 {
680680 'regex' : '.*/Dense/.*' ,
681- 'operation' : '*' ,
681+ 'operation' : _TFLOpName . ALL_SUPPORTED ,
682682 'algorithm_key' : _AlgorithmName .MIN_MAX_UNIFORM_QUANT ,
683683 'op_config' : {
684684 'weight_tensor_config' : {
685- 'dtype' : ' INT' ,
685+ 'dtype' : _TensorDataType . INT ,
686686 'num_bits' : 4 ,
687687 'symmetric' : True ,
688688 'granularity' : _QuantGranularity .TENSORWISE ,
@@ -696,11 +696,11 @@ def test_get_full_quantization_config(self):
696696 },
697697 {
698698 'regex' : '.*/Dense_1/.*' ,
699- 'operation' : ' FULLY_CONNECTED' ,
699+ 'operation' : _TFLOpName . FULLY_CONNECTED ,
700700 'algorithm_key' : _AlgorithmName .MIN_MAX_UNIFORM_QUANT ,
701701 'op_config' : {
702702 'weight_tensor_config' : {
703- 'dtype' : ' INT' ,
703+ 'dtype' : _TensorDataType . INT ,
704704 'num_bits' : 6 ,
705705 'symmetric' : True ,
706706 'granularity' : _QuantGranularity .TENSORWISE ,
@@ -714,11 +714,11 @@ def test_get_full_quantization_config(self):
714714 },
715715 {
716716 'regex' : '.*/Dense_1/.*' ,
717- 'operation' : ' BATCH_MATMUL' ,
717+ 'operation' : _TFLOpName . BATCH_MATMUL ,
718718 'algorithm_key' : _AlgorithmName .MIN_MAX_UNIFORM_QUANT ,
719719 'op_config' : {
720720 'weight_tensor_config' : {
721- 'dtype' : ' INT' ,
721+ 'dtype' : _TensorDataType . INT ,
722722 'num_bits' : 3 ,
723723 'symmetric' : True ,
724724 'granularity' : _QuantGranularity .TENSORWISE ,
@@ -987,6 +987,28 @@ def test_need_calibration_true(self):
987987 )
988988 self .assertTrue (self ._recipe_manager .need_calibration ())
989989
990+ def test_get_hadamard_with_max_size (self ):
991+ self ._recipe_manager .add_quantization_config (
992+ regex = '.*/Dense/.*' ,
993+ operation_name = _TFLOpName .FULLY_CONNECTED ,
994+ algorithm_key = _AlgorithmName .HADAMARD_ROTATION ,
995+ op_config = qtyping .OpQuantizationConfig (
996+ weight_tensor_config = _TensorQuantConfig (
997+ num_bits = 8 , algorithm_params = {'max_hadamard_size' : 1024 }
998+ ),
999+ compute_precision = _ComputePrecision .INTEGER ,
1000+ ),
1001+ )
1002+ alg_key , op_config = self ._recipe_manager .get_quantization_configs (
1003+ _TFLOpName .FULLY_CONNECTED , 'model/Dense/op'
1004+ )
1005+ self .assertEqual (alg_key , _AlgorithmName .HADAMARD_ROTATION )
1006+ weight_tensor_config = op_config .weight_tensor_config
1007+ assert weight_tensor_config is not None
1008+ self .assertEqual (
1009+ weight_tensor_config .algorithm_params ['max_hadamard_size' ], 1024
1010+ )
1011+
9901012
9911013if __name__ == '__main__' :
9921014 absltest .main ()
0 commit comments