Skip to content

Commit 9e13a36

Browse files
ai-edge-bot authored and copybara-github committed
Add max_hadamard_size parameter for Hadamard rotations.
PiperOrigin-RevId: 868388507
1 parent ee1ef41 commit 9e13a36

15 files changed

+137
-31
lines changed

ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,14 @@ def _make_hadamard_matrix(size: int) -> np.ndarray:
5454
def _rotate_with_diagonal_hadamard(
5555
tensor_content: np.ndarray,
5656
axis: int,
57+
max_size: int = 0,
5758
):
5859
"""Quantizes the given float array using the diagonal Hadamard algorithm.
5960
6061
Args:
6162
tensor_content: The float array to quantize.
6263
axis: The axis of the tensor to rotate.
64+
max_size: The maximum size of the Hadamard matrix.
6365
6466
Returns:
6567
A tuple containing the quantized array and the recovered array.
@@ -77,7 +79,9 @@ def _rotate_with_diagonal_hadamard(
7779
# Use the largest power of 2 that is a factor of the dimension and then
7880
# tile this Hadamard matrix along the diagonal. 2**30 is just a large power
7981
# of 2 to calculate this factor.
80-
hadamard_size = np.gcd(tensor_content.shape[axis], 2 ** 30)
82+
hadamard_size = np.gcd(tensor_content.shape[axis], 2**30)
83+
if max_size > 0:
84+
hadamard_size = min(hadamard_size, max_size)
8185
diagonal_size = tensor_content.shape[axis] // hadamard_size
8286
# Output size is the product of all dimensions except the one being rotated.
8387
output_size = np.prod(np.delete(tensor_content.shape, axis))
@@ -135,7 +139,9 @@ def get_tensor_quant_params(
135139

136140
# Rotate the tensor with a Hadamard matrix.
137141
w_rotated, hadamard_size, random_vector = _rotate_with_diagonal_hadamard(
138-
tensor_content, axis=reduce_axis
142+
tensor_content,
143+
axis=reduce_axis,
144+
max_size=tensor_quant_config.max_hadamard_size,
139145
)
140146

141147
# Get the quantized values of the rotated tensor.

ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Test Hadamard rotation materialization."""
1717

18+
import dataclasses
1819
import os
1920

2021
from absl.testing import parameterized
@@ -227,6 +228,36 @@ def test_get_tensor_quant_params_basic(self):
227228
if qparams.hadamard is not None:
228229
self.assertEqual(qparams.hadamard.hadamard_size, 32)
229230

231+
def test_get_tensor_quant_params_max_size(self):
232+
input_tensor = self._subgraph.tensors[self._fc_op.inputs[1]]
233+
buffer = self._graph_info.buffers[self._fc_buffer_id]
234+
np_buffer = np.frombuffer(buffer.data, dtype=np.float32).reshape(
235+
input_tensor.shape
236+
)
237+
# The original dimension is 32. The largest power of 2 factor is 32.
238+
# If we set max_hadamard_size to 16, then it should be 16.
239+
new_op_quant_config = dataclasses.replace(
240+
self._op_info.op_quant_config,
241+
weight_tensor_config=qtyping.TensorQuantizationConfig(
242+
num_bits=8,
243+
symmetric=True,
244+
granularity=qtyping.QuantGranularity.CHANNELWISE,
245+
max_hadamard_size=16,
246+
),
247+
)
248+
self._op_info = dataclasses.replace(
249+
self._op_info, op_quant_config=new_op_quant_config
250+
)
251+
qparams = hadamard_rotation.get_tensor_quant_params(
252+
self._op_info,
253+
self._op_info.op_quant_config.weight_tensor_config,
254+
np_buffer,
255+
self._tensor_name_to_qsv,
256+
)
257+
self.assertIsNotNone(qparams.hadamard)
258+
if qparams.hadamard is not None:
259+
self.assertEqual(qparams.hadamard.hadamard_size, 16)
260+
230261
def test_get_tensor_quant_params_golden_1(self):
231262
test_data = np.ones((6, 6))
232263
# expected:

ai_edge_quantizer/algorithms/utils/common_utils.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,36 @@ def check_if_valid_op_config(
9494
f"No policy was specified for op: {op_name} with config:"
9595
f" {op_quant_config}."
9696
)
97-
# The config_check_policy contains all possible valid configs, except for
98-
# variations in the min_weight_elements field (it's set to 0 for all of them).
99-
# min_weight_elements has to be ignored during policy check here because it
100-
# can be any non-negative integer, which means we can't list all possible
101-
# values in the policy.
102-
elif (
103-
dataclasses.replace(op_quant_config, min_weight_elements=0)
104-
not in config_check_policy[op_name]
105-
):
106-
error_msg = (
107-
f"Quantization config for op: {op_name} with config:"
108-
f" {op_quant_config} was not found in the policy."
109-
)
11097
else:
111-
check_passed = True
98+
# min_weight_elements and max_hadamard_size have to be ignored during
99+
# policy check here because they can be any non-negative integer, which
100+
# means we can't list all possible values in the policy.
101+
op_quant_config_to_check = dataclasses.replace(
102+
op_quant_config, min_weight_elements=0
103+
)
104+
if op_quant_config_to_check.weight_tensor_config is not None:
105+
op_quant_config_to_check = dataclasses.replace(
106+
op_quant_config_to_check,
107+
weight_tensor_config=dataclasses.replace(
108+
op_quant_config_to_check.weight_tensor_config, max_hadamard_size=0
109+
),
110+
)
111+
if op_quant_config_to_check.activation_tensor_config is not None:
112+
op_quant_config_to_check = dataclasses.replace(
113+
op_quant_config_to_check,
114+
activation_tensor_config=dataclasses.replace(
115+
op_quant_config_to_check.activation_tensor_config,
116+
max_hadamard_size=0,
117+
),
118+
)
119+
120+
if op_quant_config_to_check not in config_check_policy[op_name]:
121+
error_msg = (
122+
f"Quantization config for op: {op_name} with config:"
123+
f" {op_quant_config!r} was not found in the policy."
124+
)
125+
else:
126+
check_passed = True
112127

113128
if not check_passed:
114129
raise ValueError(

ai_edge_quantizer/algorithms/utils/common_utils_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,19 @@ def test_check_drq_config_with_non_default_min_weight_elements_succeeds(self):
224224
_TFLOpName.CONV_2D, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
225225
)
226226

227+
def test_check_config_with_non_default_max_hadamard_size_succeeds(self):
228+
op_quant_config = _OpQuantConfig(
229+
weight_tensor_config=_TensorQuantConfig(
230+
num_bits=8,
231+
granularity=qtyping.QuantGranularity.CHANNELWISE,
232+
max_hadamard_size=1024,
233+
),
234+
compute_precision=_ComputePrecision.INTEGER, # DRQ.
235+
)
236+
common_utils.check_if_valid_op_config(
237+
_TFLOpName.FULLY_CONNECTED, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
238+
)
239+
227240
@parameterized.product(
228241
op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
229242
act_num_bits=(8, 16),

ai_edge_quantizer/qtyping.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,12 +317,15 @@ class TensorQuantizationConfig:
317317
quantization.
318318
dtype: The data type of the tensor.
319319
algorithm_key: The algorithm key to use for quantization.
320+
max_hadamard_size: The maximum size of the Hadamard matrix to use for
321+
Hadamard rotation.
320322
"""
321323

322324
num_bits: int
323325
symmetric: bool = True
324326
granularity: QuantGranularity = QuantGranularity.TENSORWISE
325327
dtype: TensorDataType = TensorDataType.INT
328+
max_hadamard_size: int = 0
326329

327330
def to_dict(self) -> dict[str, Any]:
328331
"""Converts ActivationQuantizationConfig to dict."""

ai_edge_quantizer/recipe_manager_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,12 +581,14 @@ def test_get_full_quantization_config(self):
581581
'symmetric': False,
582582
'granularity': _QuantGranularity.TENSORWISE,
583583
'dtype': 'INT',
584+
'max_hadamard_size': 0,
584585
},
585586
'weight_tensor_config': {
586587
'num_bits': 8,
587588
'symmetric': True,
588589
'granularity': _QuantGranularity.TENSORWISE,
589590
'dtype': 'INT',
591+
'max_hadamard_size': 0,
590592
},
591593
# WEIGHT_ONLY.
592594
'compute_precision': _ComputePrecision.INTEGER,
@@ -605,6 +607,7 @@ def test_get_full_quantization_config(self):
605607
'num_bits': 8,
606608
'symmetric': True,
607609
'granularity': _QuantGranularity.TENSORWISE,
610+
'max_hadamard_size': 0,
608611
},
609612
# WEIGHT_ONLY.
610613
'compute_precision': _ComputePrecision.FLOAT,
@@ -623,6 +626,7 @@ def test_get_full_quantization_config(self):
623626
'num_bits': 4,
624627
'symmetric': True,
625628
'granularity': _QuantGranularity.TENSORWISE,
629+
'max_hadamard_size': 0,
626630
},
627631
# WEIGHT_ONLY.
628632
'compute_precision': _ComputePrecision.FLOAT,
@@ -641,6 +645,7 @@ def test_get_full_quantization_config(self):
641645
'num_bits': 6,
642646
'symmetric': True,
643647
'granularity': _QuantGranularity.TENSORWISE,
648+
'max_hadamard_size': 0,
644649
},
645650
# WEIGHT_ONLY.
646651
'compute_precision': _ComputePrecision.FLOAT,
@@ -659,6 +664,7 @@ def test_get_full_quantization_config(self):
659664
'num_bits': 3,
660665
'symmetric': True,
661666
'granularity': _QuantGranularity.TENSORWISE,
667+
'max_hadamard_size': 0,
662668
},
663669
# WEIGHT_ONLY.
664670
'compute_precision': _ComputePrecision.FLOAT,
@@ -924,6 +930,26 @@ def test_need_calibration_true(self):
924930
)
925931
self.assertTrue(self._recipe_manager.need_calibration())
926932

933+
def test_get_hadamard_with_max_size(self):
934+
self._recipe_manager.add_quantization_config(
935+
regex='.*/Dense/.*',
936+
operation_name=_TFLOpName.FULLY_CONNECTED,
937+
algorithm_key=_AlgorithmName.HADAMARD_ROTATION,
938+
op_config=qtyping.OpQuantizationConfig(
939+
weight_tensor_config=_TensorQuantConfig(
940+
num_bits=8, max_hadamard_size=1024
941+
),
942+
compute_precision=_ComputePrecision.INTEGER,
943+
),
944+
)
945+
alg_key, op_config = self._recipe_manager.get_quantization_configs(
946+
_TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
947+
)
948+
self.assertEqual(alg_key, _AlgorithmName.HADAMARD_ROTATION)
949+
weight_tensor_config = op_config.weight_tensor_config
950+
assert weight_tensor_config is not None
951+
self.assertEqual(weight_tensor_config.max_hadamard_size, 1024)
952+
927953

928954
if __name__ == '__main__':
929955
googletest.main()

ai_edge_quantizer/recipes/default_a16w8_recipe.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
"num_bits": 16,
99
"symmetric": true,
1010
"granularity": "TENSORWISE",
11-
"dtype": "INT"
11+
"dtype": "INT",
12+
"max_hadamard_size": 0
1213
},
1314
"weight_tensor_config": {
1415
"num_bits": 8,
1516
"symmetric": true,
1617
"granularity": "CHANNELWISE",
17-
"dtype": "INT"
18+
"dtype": "INT",
19+
"max_hadamard_size": 0
1820
},
1921
"compute_precision": "INTEGER",
2022
"explicit_dequantize": false,

ai_edge_quantizer/recipes/default_a8w8_recipe.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
"num_bits": 8,
99
"symmetric": false,
1010
"granularity": "TENSORWISE",
11-
"dtype": "INT"
11+
"dtype": "INT",
12+
"max_hadamard_size": 0
1213
},
1314
"weight_tensor_config": {
1415
"num_bits": 8,
1516
"symmetric": true,
1617
"granularity": "CHANNELWISE",
17-
"dtype": "INT"
18+
"dtype": "INT",
19+
"max_hadamard_size": 0
1820
},
1921
"compute_precision": "INTEGER",
2022
"explicit_dequantize": false,

ai_edge_quantizer/recipes/default_af32w4float_recipe.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
"num_bits": 4,
99
"symmetric": false,
1010
"granularity": "CHANNELWISE",
11-
"dtype": "INT"
11+
"dtype": "INT",
12+
"max_hadamard_size": 0
1213
},
1314
"compute_precision": "FLOAT",
1415
"explicit_dequantize": true,

ai_edge_quantizer/recipes/default_af32w8float_recipe.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
"num_bits": 8,
99
"symmetric": false,
1010
"granularity": "CHANNELWISE",
11-
"dtype": "INT"
11+
"dtype": "INT",
12+
"max_hadamard_size": 0
1213
},
1314
"compute_precision": "FLOAT",
1415
"explicit_dequantize": true,

0 commit comments

Comments (0)