Skip to content

Commit ec3402f

Browse files
ai-edge-bot authored and copybara-github committed
Add max_hadamard_size parameter for Hadamard rotations.
PiperOrigin-RevId: 888789202
1 parent 6fb7140 commit ec3402f

File tree

10 files changed

+185
-51
lines changed

10 files changed

+185
-51
lines changed

ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,24 @@ def _make_hadamard_matrix(size: int) -> np.ndarray:
6161
h_int8 = np.kron(h_int8, h2)
6262
current_size *= 2
6363
if current_size not in _HADAMARD_MATRIX_CACHE:
64-
h_norm = h_int8 / np.sqrt(size, dtype=np.float32)
64+
h_norm = h_int8 / np.sqrt(current_size, dtype=np.float32)
6565
_HADAMARD_MATRIX_CACHE[current_size] = h_norm
66+
else:
67+
h_norm = _HADAMARD_MATRIX_CACHE[current_size]
6668
return h_norm
6769

6870

6971
def _rotate_with_diagonal_hadamard(
7072
tensor_content: np.ndarray,
7173
axis: int,
74+
max_size: int | None = None,
7275
):
7376
"""Quantizes the given float array using the diagonal Hadamard algorithm.
7477
7578
Args:
7679
tensor_content: The float array to quantize.
7780
axis: The axis of the tensor to rotate.
81+
max_size: The maximum size of the Hadamard matrix.
7882
7983
Returns:
8084
A tuple containing the quantized array and the recovered array.
@@ -93,6 +97,8 @@ def _rotate_with_diagonal_hadamard(
9397
# tile this Hadamard matrix along the diagonal. 2**30 is just a large power
9498
# of 2 to calculate this factor.
9599
hadamard_size = np.gcd(tensor_content.shape[axis], 2**30)
100+
if max_size:
101+
hadamard_size = min(hadamard_size, 1 << (max_size.bit_length() - 1))
96102
random_vector = np.ones(hadamard_size, dtype=np.int8)
97103

98104
# Use a canonical Hadamard matrix.
@@ -150,7 +156,9 @@ def get_tensor_quant_params(
150156

151157
# Rotate the tensor with a Hadamard matrix.
152158
w_rotated, hadamard_size, random_vector = _rotate_with_diagonal_hadamard(
153-
tensor_content, axis=reduce_axis
159+
tensor_content,
160+
axis=reduce_axis,
161+
max_size=tensor_quant_config.algorithm_params.get("max_hadamard_size"),
154162
)
155163

156164
# Get the quantized values of the rotated tensor.

ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Test Hadamard rotation materialization."""
1717

18+
import dataclasses
1819
import pathlib
1920

2021
from absl.testing import absltest
@@ -240,6 +241,36 @@ def test_get_tensor_quant_params_basic(self):
240241
if qparams.hadamard is not None:
241242
self.assertEqual(qparams.hadamard.hadamard_size, 32)
242243

244+
def test_get_tensor_quant_params_max_size(self):
245+
input_tensor = self._subgraph.tensors[self._fc_op.inputs[1]]
246+
buffer = self._graph_info.buffers[self._fc_buffer_id]
247+
np_buffer = np.frombuffer(buffer.data, dtype=np.float32).reshape(
248+
input_tensor.shape
249+
)
250+
# The original dimension is 32. The largest power of 2 factor is 32.
251+
# If we set algorithm_params to {'max_hadamard_size': 16}, then it should be
252+
# 16.
253+
new_op_quant_config = dataclasses.replace(
254+
self._op_info.op_quant_config,
255+
weight_tensor_config=qtyping.TensorQuantizationConfig(
256+
num_bits=8,
257+
symmetric=True,
258+
granularity=qtyping.QuantGranularity.CHANNELWISE,
259+
algorithm_params={"max_hadamard_size": 16},
260+
),
261+
)
262+
self._op_info = dataclasses.replace(
263+
self._op_info, op_quant_config=new_op_quant_config
264+
)
265+
qparams = hadamard_rotation.get_tensor_quant_params(
266+
self._op_info,
267+
self._op_info.op_quant_config.weight_tensor_config,
268+
np_buffer,
269+
self._tensor_name_to_qsv,
270+
)
271+
self.assertIsNotNone(qparams.hadamard)
272+
self.assertEqual(qparams.hadamard.hadamard_size, 16)
273+
243274
def test_get_tensor_quant_params_golden_1(self):
244275
test_data = np.ones((6, 6))
245276
# expected:

ai_edge_quantizer/algorithms/utils/common_utils.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -126,21 +126,36 @@ def check_if_valid_op_config(
126126
f"No policy was specified for op: {op_name} with config:"
127127
f" {op_quant_config}."
128128
)
129-
# The config_check_policy contains all possible valid configs, except for
130-
# variations in the min_weight_elements field (it's set to 0 for all of them).
131-
# min_weight_elements has to be ignored during policy check here because it
132-
# can be any non-negative integer, which means we can't list all possible
133-
# values in the policy.
134-
elif (
135-
dataclasses.replace(op_quant_config, min_weight_elements=0)
136-
not in config_check_policy[op_name]
137-
):
138-
error_msg = (
139-
f"Quantization config for op: {op_name} with config:"
140-
f" {op_quant_config} was not found in the policy."
141-
)
142129
else:
143-
check_passed = True
130+
# min_weight_elements and algorithm_params have to be ignored during
131+
# policy check here because they can be any non-negative integer or dict,
132+
# which means we can't list all possible values in the policy.
133+
op_quant_config_to_check = dataclasses.replace(
134+
op_quant_config, min_weight_elements=0
135+
)
136+
if op_quant_config_to_check.weight_tensor_config is not None:
137+
op_quant_config_to_check = dataclasses.replace(
138+
op_quant_config_to_check,
139+
weight_tensor_config=dataclasses.replace(
140+
op_quant_config_to_check.weight_tensor_config, algorithm_params={}
141+
),
142+
)
143+
if op_quant_config_to_check.activation_tensor_config is not None:
144+
op_quant_config_to_check = dataclasses.replace(
145+
op_quant_config_to_check,
146+
activation_tensor_config=dataclasses.replace(
147+
op_quant_config_to_check.activation_tensor_config,
148+
algorithm_params={},
149+
),
150+
)
151+
152+
if op_quant_config_to_check not in config_check_policy[op_name]:
153+
error_msg = (
154+
f"Quantization config for op: {op_name} with config:"
155+
f" {op_quant_config!r} was not found in the policy."
156+
)
157+
else:
158+
check_passed = True
144159

145160
if not check_passed:
146161
raise ValueError(

ai_edge_quantizer/algorithms/utils/common_utils_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,21 @@ def test_check_drq_config_with_non_default_min_weight_elements_succeeds(self):
224224
_TFLOpName.CONV_2D, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
225225
)
226226

227+
def test_check_config_with_non_default_algorithm_params_succeeds(self):
228+
op_quant_config = _OpQuantConfig(
229+
weight_tensor_config=_TensorQuantConfig(
230+
num_bits=8,
231+
granularity=qtyping.QuantGranularity.CHANNELWISE,
232+
algorithm_params={"max_hadamard_size": 1024},
233+
),
234+
compute_precision=_ComputePrecision.INTEGER, # DRQ.
235+
)
236+
common_utils.check_if_valid_op_config(
237+
_TFLOpName.FULLY_CONNECTED,
238+
op_quant_config,
239+
_DEFAULT_CONFIG_CHECK_POLICY,
240+
)
241+
227242
@parameterized.product(
228243
op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
229244
act_num_bits=(8, 16),

ai_edge_quantizer/qtyping.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@
2020
import copy
2121
import dataclasses
2222
import enum
23-
from typing import Any, Callable, Optional, Union
24-
23+
from typing import Any, Callable, Mapping, Optional, Union, TypeAlias
24+
from immutabledict import immutabledict
2525
import numpy as np
26-
from typing_extensions import TypeAlias
2726

2827
from ai_edge_litert.tools import flatbuffer_utils
2928

@@ -355,22 +354,36 @@ class TensorQuantizationConfig:
355354
quantization.
356355
dtype: The data type of the tensor.
357356
algorithm_key: The algorithm key to use for quantization.
357+
algorithm_params: Additional parameters for the quantization algorithm.
358358
"""
359359

360360
num_bits: int
361361
symmetric: bool = True
362362
granularity: QuantGranularity = QuantGranularity.TENSORWISE
363363
dtype: TensorDataType = TensorDataType.INT
364+
algorithm_params: Mapping[str, Any] = dataclasses.field(
365+
default_factory=immutabledict
366+
)
367+
368+
def __post_init__(self):
369+
if not isinstance(self.algorithm_params, immutabledict):
370+
object.__setattr__(
371+
self, 'algorithm_params', immutabledict(self.algorithm_params)
372+
)
364373

365374
def to_dict(self) -> dict[str, Any]:
366375
"""Converts ActivationQuantizationConfig to dict."""
367376
return dataclasses.asdict(
368377
self,
369378
dict_factory=lambda x: { # pylint: disable=g-long-lambda
370-
k: v
379+
k: (
380+
dict(v)
381+
if isinstance(v, Mapping) and not isinstance(v, dict)
382+
else v
383+
)
371384
for (k, v) in x
372385
# Skip None and empty dict values.
373-
if v is not None and not (isinstance(v, dict) and not v)
386+
if v is not None and not (isinstance(v, Mapping) and not v)
374387
},
375388
)
376389

@@ -380,7 +393,15 @@ def from_dict(cls, params: dict[str, Any]) -> 'TensorQuantizationConfig':
380393
params_copy = copy.deepcopy(params)
381394
# Process block_size config from legacy recipe.
382395
params_copy = _process_block_size(params_copy)
383-
return cls(**params_copy)
396+
397+
# Move any unknown fields to algorithm_params for backward compatibility.
398+
known_fields = {f.name for f in dataclasses.fields(cls)}
399+
algorithm_params = params_copy.pop('algorithm_params', {})
400+
for key in list(params_copy.keys()):
401+
if key not in known_fields:
402+
algorithm_params[key] = params_copy.pop(key)
403+
404+
return cls(algorithm_params=algorithm_params, **params_copy)
384405

385406

386407
def _process_block_size(params: dict[str, Any]) -> dict[str, Any]:
@@ -462,10 +483,14 @@ def to_dict(self) -> dict[str, Any]:
462483
return dataclasses.asdict(
463484
self,
464485
dict_factory=lambda x: { # pylint: disable=g-long-lambda
465-
k: v
486+
k: (
487+
dict(v)
488+
if isinstance(v, Mapping) and not isinstance(v, dict)
489+
else v
490+
)
466491
for (k, v) in x
467492
# Skip None and empty dict values.
468-
if v is not None and not (isinstance(v, dict) and not v)
493+
if v is not None and not (isinstance(v, Mapping) and not v)
469494
},
470495
)
471496

ai_edge_quantizer/recipe_manager.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import collections
1919
import dataclasses
2020
import re
21-
from typing import Any, Optional
21+
from typing import Any, Mapping, Optional
2222
from absl import logging
2323
from ai_edge_quantizer import algorithm_manager
2424
from ai_edge_quantizer import qtyping
@@ -205,16 +205,28 @@ def get_quantization_recipe(self) -> ModelQuantizationRecipe:
205205
Returns:
206206
A list of quantization configs in the recipe.
207207
"""
208-
ret = []
209-
for _, scope_config in self._scope_configs.items():
210-
for quant_config in scope_config:
211-
config = dict()
212-
config['regex'] = quant_config.regex
213-
config['operation'] = quant_config.operation
214-
config['algorithm_key'] = quant_config.algorithm_key
215-
config['op_config'] = quant_config.op_config.to_dict()
216-
ret.append(config)
217-
return ret
208+
recipe = []
209+
for _, op_recipes in self._scope_configs.items():
210+
for op_recipe in op_recipes:
211+
recipe_dict = dataclasses.asdict(
212+
op_recipe,
213+
dict_factory=lambda x: { # pylint: disable=g-long-lambda
214+
k: (
215+
dict(v)
216+
if isinstance(v, Mapping)
217+
and not isinstance(v, dict)
218+
else v
219+
)
220+
for (k, v) in x
221+
# Skip None and empty dict values.
222+
if v is not None
223+
and not (
224+
isinstance(v, (dict, Mapping)) and not v
225+
)
226+
},
227+
)
228+
recipe.append(recipe_dict)
229+
return recipe
218230

219231
def load_quantization_recipe(
220232
self, quantization_recipe: ModelQuantizationRecipe

0 commit comments

Comments (0)