Skip to content

Commit 3c9d855

Browse files
ai-edge-bot authored and copybara-github committed
Add max_hadamard_size parameter for Hadamard rotations.
PiperOrigin-RevId: 868388507
1 parent 0dec1ad commit 3c9d855

16 files changed

+212
-59
lines changed

ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,14 @@ def _make_hadamard_matrix(size: int) -> np.ndarray:
5454
def _rotate_with_diagonal_hadamard(
5555
tensor_content: np.ndarray,
5656
axis: int,
57+
max_size: int = 0,
5758
):
5859
"""Quantizes the given float array using the diagonal Hadamard algorithm.
5960
6061
Args:
6162
tensor_content: The float array to quantize.
6263
axis: The axis of the tensor to rotate.
64+
max_size: The maximum size of the Hadamard matrix.
6365
6466
Returns:
6567
A tuple containing the quantized array and the recovered array.
@@ -77,7 +79,9 @@ def _rotate_with_diagonal_hadamard(
7779
# Use the largest power of 2 that is a factor of the dimension and then
7880
# tile this Hadamard matrix along the diagonal. 2**30 is just a large power
7981
# of 2 to calculate this factor.
80-
hadamard_size = np.gcd(tensor_content.shape[axis], 2 ** 30)
82+
hadamard_size = np.gcd(tensor_content.shape[axis], 2**30)
83+
if max_size > 0:
84+
hadamard_size = min(hadamard_size, max_size)
8185
diagonal_size = tensor_content.shape[axis] // hadamard_size
8286
# Output size is the product of all dimensions except the one being rotated.
8387
output_size = np.prod(np.delete(tensor_content.shape, axis))
@@ -135,7 +139,9 @@ def get_tensor_quant_params(
135139

136140
# Rotate the tensor with a Hadamard matrix.
137141
w_rotated, hadamard_size, random_vector = _rotate_with_diagonal_hadamard(
138-
tensor_content, axis=reduce_axis
142+
tensor_content,
143+
axis=reduce_axis,
144+
max_size=tensor_quant_config.algorithm_params.get("max_hadamard_size", 0),
139145
)
140146

141147
# Get the quantized values of the rotated tensor.

ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Test Hadamard rotation materialization."""
1717

18+
import dataclasses
1819
import pathlib
1920

2021
from absl.testing import absltest
@@ -227,6 +228,36 @@ def test_get_tensor_quant_params_basic(self):
227228
if qparams.hadamard is not None:
228229
self.assertEqual(qparams.hadamard.hadamard_size, 32)
229230

231+
def test_get_tensor_quant_params_max_size(self):
232+
input_tensor = self._subgraph.tensors[self._fc_op.inputs[1]]
233+
buffer = self._graph_info.buffers[self._fc_buffer_id]
234+
np_buffer = np.frombuffer(buffer.data, dtype=np.float32).reshape(
235+
input_tensor.shape
236+
)
237+
# The original dimension is 32. The largest power of 2 factor is 32.
238+
# If we set algorithm_params to {'max_hadamard_size': 16}, then it should be 16.
239+
new_op_quant_config = dataclasses.replace(
240+
self._op_info.op_quant_config,
241+
weight_tensor_config=qtyping.TensorQuantizationConfig(
242+
num_bits=8,
243+
symmetric=True,
244+
granularity=qtyping.QuantGranularity.CHANNELWISE,
245+
algorithm_params={"max_hadamard_size": 16},
246+
),
247+
)
248+
self._op_info = dataclasses.replace(
249+
self._op_info, op_quant_config=new_op_quant_config
250+
)
251+
qparams = hadamard_rotation.get_tensor_quant_params(
252+
self._op_info,
253+
self._op_info.op_quant_config.weight_tensor_config,
254+
np_buffer,
255+
self._tensor_name_to_qsv,
256+
)
257+
self.assertIsNotNone(qparams.hadamard)
258+
if qparams.hadamard is not None:
259+
self.assertEqual(qparams.hadamard.hadamard_size, 16)
260+
230261
def test_get_tensor_quant_params_golden_1(self):
231262
test_data = np.ones((6, 6))
232263
# expected:

ai_edge_quantizer/algorithms/utils/common_utils.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,36 @@ def check_if_valid_op_config(
9494
f"No policy was specified for op: {op_name} with config:"
9595
f" {op_quant_config}."
9696
)
97-
# The config_check_policy contains all possible valid configs, except for
98-
# variations in the min_weight_elements field (it's set to 0 for all of them).
99-
# min_weight_elements has to be ignored during policy check here because it
100-
# can be any non-negative integer, which means we can't list all possible
101-
# values in the policy.
102-
elif (
103-
dataclasses.replace(op_quant_config, min_weight_elements=0)
104-
not in config_check_policy[op_name]
105-
):
106-
error_msg = (
107-
f"Quantization config for op: {op_name} with config:"
108-
f" {op_quant_config} was not found in the policy."
109-
)
11097
else:
111-
check_passed = True
98+
# min_weight_elements and algorithm_params have to be ignored during
99+
# policy check here because they can be any non-negative integer or dict,
100+
# which means we can't list all possible values in the policy.
101+
op_quant_config_to_check = dataclasses.replace(
102+
op_quant_config, min_weight_elements=0
103+
)
104+
if op_quant_config_to_check.weight_tensor_config is not None:
105+
op_quant_config_to_check = dataclasses.replace(
106+
op_quant_config_to_check,
107+
weight_tensor_config=dataclasses.replace(
108+
op_quant_config_to_check.weight_tensor_config, algorithm_params={}
109+
),
110+
)
111+
if op_quant_config_to_check.activation_tensor_config is not None:
112+
op_quant_config_to_check = dataclasses.replace(
113+
op_quant_config_to_check,
114+
activation_tensor_config=dataclasses.replace(
115+
op_quant_config_to_check.activation_tensor_config,
116+
algorithm_params={},
117+
),
118+
)
119+
120+
if op_quant_config_to_check not in config_check_policy[op_name]:
121+
error_msg = (
122+
f"Quantization config for op: {op_name} with config:"
123+
f" {op_quant_config!r} was not found in the policy."
124+
)
125+
else:
126+
check_passed = True
112127

113128
if not check_passed:
114129
raise ValueError(

ai_edge_quantizer/algorithms/utils/common_utils_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,21 @@ def test_check_drq_config_with_non_default_min_weight_elements_succeeds(self):
224224
_TFLOpName.CONV_2D, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
225225
)
226226

227+
def test_check_config_with_non_default_algorithm_params_succeeds(self):
228+
op_quant_config = _OpQuantConfig(
229+
weight_tensor_config=_TensorQuantConfig(
230+
num_bits=8,
231+
granularity=qtyping.QuantGranularity.CHANNELWISE,
232+
algorithm_params={"max_hadamard_size": 1024},
233+
),
234+
compute_precision=_ComputePrecision.INTEGER, # DRQ.
235+
)
236+
common_utils.check_if_valid_op_config(
237+
_TFLOpName.FULLY_CONNECTED,
238+
op_quant_config,
239+
_DEFAULT_CONFIG_CHECK_POLICY,
240+
)
241+
227242
@parameterized.product(
228243
op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
229244
act_num_bits=(8, 16),

ai_edge_quantizer/qtyping.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@
2020
import copy
2121
import dataclasses
2222
import enum
23-
from typing import Any, Callable, Optional, Union
24-
23+
from typing import Any, Callable, Mapping, Optional, Union, TypeAlias
24+
from immutabledict import immutabledict
2525
import numpy as np
26-
from typing_extensions import TypeAlias
2726

2827

2928
QSV: TypeAlias = MutableMapping[str, Any]
@@ -317,22 +316,32 @@ class TensorQuantizationConfig:
317316
quantization.
318317
dtype: The data type of the tensor.
319318
algorithm_key: The algorithm key to use for quantization.
319+
algorithm_params: Additional parameters for the quantization algorithm.
320320
"""
321321

322322
num_bits: int
323323
symmetric: bool = True
324324
granularity: QuantGranularity = QuantGranularity.TENSORWISE
325325
dtype: TensorDataType = TensorDataType.INT
326+
algorithm_params: Mapping[str, Any] = dataclasses.field(
327+
default_factory=immutabledict
328+
)
329+
330+
def __post_init__(self):
331+
if not isinstance(self.algorithm_params, immutabledict):
332+
object.__setattr__(
333+
self, 'algorithm_params', immutabledict(self.algorithm_params)
334+
)
326335

327336
def to_dict(self) -> dict[str, Any]:
328337
"""Converts ActivationQuantizationConfig to dict."""
329338
return dataclasses.asdict(
330339
self,
331340
dict_factory=lambda x: { # pylint: disable=g-long-lambda
332-
k: v
341+
k: (dict(v) if isinstance(v, Mapping) and not isinstance(v, dict) else v)
333342
for (k, v) in x
334343
# Skip None and empty dict values.
335-
if v is not None and not (isinstance(v, dict) and not v)
344+
if v is not None and not (isinstance(v, (dict, Mapping)) and not v)
336345
},
337346
)
338347

@@ -342,6 +351,15 @@ def from_dict(cls, params: dict[str, Any]) -> 'TensorQuantizationConfig':
342351
params_copy = copy.deepcopy(params)
343352
# Process block_size config from legacy recipe.
344353
params_copy = _process_block_size(params_copy)
354+
355+
# Move any unknown fields to algorithm_params for backward compatibility.
356+
known_fields = {f.name for f in dataclasses.fields(cls)}
357+
algorithm_params = params_copy.pop('algorithm_params', {})
358+
for key in list(params_copy.keys()):
359+
if key not in known_fields:
360+
algorithm_params[key] = params_copy.pop(key)
361+
params_copy['algorithm_params'] = algorithm_params
362+
345363
return cls(**params_copy)
346364

347365

@@ -424,10 +442,10 @@ def to_dict(self) -> dict[str, Any]:
424442
return dataclasses.asdict(
425443
self,
426444
dict_factory=lambda x: { # pylint: disable=g-long-lambda
427-
k: v
445+
k: (dict(v) if isinstance(v, Mapping) and not isinstance(v, dict) else v)
428446
for (k, v) in x
429447
# Skip None and empty dict values.
430-
if v is not None and not (isinstance(v, dict) and not v)
448+
if v is not None and not (isinstance(v, (dict, Mapping)) and not v)
431449
},
432450
)
433451

ai_edge_quantizer/recipe_manager.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,16 +205,26 @@ def get_quantization_recipe(self) -> ModelQuantizationRecipe:
205205
Returns:
206206
A list of quantization configs in the recipe.
207207
"""
208-
ret = []
209-
for _, scope_config in self._scope_configs.items():
210-
for quant_config in scope_config:
211-
config = dict()
212-
config['regex'] = quant_config.regex
213-
config['operation'] = quant_config.operation
214-
config['algorithm_key'] = quant_config.algorithm_key
215-
config['op_config'] = quant_config.op_config.to_dict()
216-
ret.append(config)
217-
return ret
208+
recipe = []
209+
for scope, op_recipes in self._scope_configs.items():
210+
for op_recipe in op_recipes:
211+
recipe_dict = dataclasses.asdict(
212+
op_recipe,
213+
dict_factory=lambda x: { # pylint: disable=g-long-lambda
214+
k: (
215+
dict(v)
216+
if isinstance(v, collections.abc.Mapping)
217+
and not isinstance(v, dict)
218+
else v
219+
)
220+
for (k, v) in x
221+
# Skip None and empty dict values.
222+
if v is not None
223+
and not (isinstance(v, (dict, collections.abc.Mapping)) and not v)
224+
},
225+
)
226+
recipe.append(recipe_dict)
227+
return recipe
218228

219229
def load_quantization_recipe(
220230
self, quantization_recipe: ModelQuantizationRecipe

ai_edge_quantizer/recipe_manager_test.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -636,20 +636,20 @@ def test_get_full_quantization_config(self):
636636
expected_full_quantization_config = [
637637
{
638638
'regex': '.*',
639-
'operation': '*',
639+
'operation': _TFLOpName.ALL_SUPPORTED,
640640
'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
641641
'op_config': {
642642
'activation_tensor_config': {
643643
'num_bits': 8,
644644
'symmetric': False,
645645
'granularity': _QuantGranularity.TENSORWISE,
646-
'dtype': 'INT',
646+
'dtype': _TensorDataType.INT,
647647
},
648648
'weight_tensor_config': {
649649
'num_bits': 8,
650650
'symmetric': True,
651651
'granularity': _QuantGranularity.TENSORWISE,
652-
'dtype': 'INT',
652+
'dtype': _TensorDataType.INT,
653653
},
654654
# WEIGHT_ONLY.
655655
'compute_precision': _ComputePrecision.INTEGER,
@@ -660,11 +660,11 @@ def test_get_full_quantization_config(self):
660660
},
661661
{
662662
'regex': '.*',
663-
'operation': 'BATCH_MATMUL',
663+
'operation': _TFLOpName.BATCH_MATMUL,
664664
'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
665665
'op_config': {
666666
'weight_tensor_config': {
667-
'dtype': 'INT',
667+
'dtype': _TensorDataType.INT,
668668
'num_bits': 8,
669669
'symmetric': True,
670670
'granularity': _QuantGranularity.TENSORWISE,
@@ -678,11 +678,11 @@ def test_get_full_quantization_config(self):
678678
},
679679
{
680680
'regex': '.*/Dense/.*',
681-
'operation': '*',
681+
'operation': _TFLOpName.ALL_SUPPORTED,
682682
'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
683683
'op_config': {
684684
'weight_tensor_config': {
685-
'dtype': 'INT',
685+
'dtype': _TensorDataType.INT,
686686
'num_bits': 4,
687687
'symmetric': True,
688688
'granularity': _QuantGranularity.TENSORWISE,
@@ -696,11 +696,11 @@ def test_get_full_quantization_config(self):
696696
},
697697
{
698698
'regex': '.*/Dense_1/.*',
699-
'operation': 'FULLY_CONNECTED',
699+
'operation': _TFLOpName.FULLY_CONNECTED,
700700
'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
701701
'op_config': {
702702
'weight_tensor_config': {
703-
'dtype': 'INT',
703+
'dtype': _TensorDataType.INT,
704704
'num_bits': 6,
705705
'symmetric': True,
706706
'granularity': _QuantGranularity.TENSORWISE,
@@ -714,11 +714,11 @@ def test_get_full_quantization_config(self):
714714
},
715715
{
716716
'regex': '.*/Dense_1/.*',
717-
'operation': 'BATCH_MATMUL',
717+
'operation': _TFLOpName.BATCH_MATMUL,
718718
'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
719719
'op_config': {
720720
'weight_tensor_config': {
721-
'dtype': 'INT',
721+
'dtype': _TensorDataType.INT,
722722
'num_bits': 3,
723723
'symmetric': True,
724724
'granularity': _QuantGranularity.TENSORWISE,
@@ -987,6 +987,28 @@ def test_need_calibration_true(self):
987987
)
988988
self.assertTrue(self._recipe_manager.need_calibration())
989989

990+
def test_get_hadamard_with_max_size(self):
991+
self._recipe_manager.add_quantization_config(
992+
regex='.*/Dense/.*',
993+
operation_name=_TFLOpName.FULLY_CONNECTED,
994+
algorithm_key=_AlgorithmName.HADAMARD_ROTATION,
995+
op_config=qtyping.OpQuantizationConfig(
996+
weight_tensor_config=_TensorQuantConfig(
997+
num_bits=8, algorithm_params={'max_hadamard_size': 1024}
998+
),
999+
compute_precision=_ComputePrecision.INTEGER,
1000+
),
1001+
)
1002+
alg_key, op_config = self._recipe_manager.get_quantization_configs(
1003+
_TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
1004+
)
1005+
self.assertEqual(alg_key, _AlgorithmName.HADAMARD_ROTATION)
1006+
weight_tensor_config = op_config.weight_tensor_config
1007+
assert weight_tensor_config is not None
1008+
self.assertEqual(
1009+
weight_tensor_config.algorithm_params['max_hadamard_size'], 1024
1010+
)
1011+
9901012

9911013
if __name__ == '__main__':
9921014
absltest.main()

ai_edge_quantizer/recipes/default_a16w8_recipe.json

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,19 @@
88
"num_bits": 16,
99
"symmetric": true,
1010
"granularity": "TENSORWISE",
11-
"dtype": "INT"
11+
"dtype": "INT",
12+
"algorithm_params": {
13+
"max_hadamard_size": 0
14+
}
1215
},
1316
"weight_tensor_config": {
1417
"num_bits": 8,
1518
"symmetric": true,
1619
"granularity": "CHANNELWISE",
17-
"dtype": "INT"
20+
"dtype": "INT",
21+
"algorithm_params": {
22+
"max_hadamard_size": 0
23+
}
1824
},
1925
"compute_precision": "INTEGER",
2026
"explicit_dequantize": false,

0 commit comments

Comments (0)