Skip to content

Commit 097249a

Browse files
phoenix-meadowlark authored and copybara-github committed
Add support for asymmetric fake-quantization to AQTv2.
Integration of native quantization with biases will require computing the cross terms, likely in the AQT operation quantizer (`DefaultGeneralQuantizer`). Itemized changes: - `AqtNumerics`: - Rename `AqtNumerics.abs_val_mapped_to` to `AqtNumerics.get_scaled_bound` to reflect that the calibration bound may span the whole quantization range (instead of ~half the range for a strictly linear transformation). - Refactor `IntNumerics` into `BaseIntNumerics`, `SymIntNumerics` and `AsymIntNumerics`. - `AsymIntNumerics` doesn't need `preserve_zero` or `preserve_max_val`. - Add `MinMaxCalibration`. I additionally tested this change by training MNIST models using `flax_e2e_model`. With symmetric quantization the model fails to converge for `config.config_v4(fwd_bits=2, dlhs_bits=None, drhs_bits=None)` (due to `NaN` losses). With asymmetric quantization the model converges even with `config.config_v4(fwd_bits=2, dlhs_bits=2, drhs_bits=4)`. PiperOrigin-RevId: 651580879
1 parent d2cfb75 commit 097249a

19 files changed

+926
-388
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ from aqt.jax.v2 import utils as aqt_utils
169169
from aqt.jax.v2.numerics import int_numerics
170170

171171
q = aqt_quantizer.Quantizer(
172-
numerics=int_numerics.IntNumerics(
172+
numerics=int_numerics.SymIntNumerics(
173173
bits=4,
174174
preserve_zero=True,
175175
preserve_max_val=True,

aqt/jax/v2/aqt_conv_general_test.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import functools
16+
1517
from absl.testing import absltest
1618
from absl.testing import parameterized
1719
from aqt.jax.v2 import aqt_quantizer
@@ -28,6 +30,16 @@ def rand_unif(shape, maxval, seed, dtype=jnp.float32):
2830
)
2931

3032

33+
def _apply_po2_scale(quantizer):
34+
calibration_cls = quantizer.calibration
35+
keywords = {}
36+
if isinstance(calibration_cls, functools.partial):
37+
keywords = calibration_cls.keywords
38+
calibration_cls = calibration_cls.func
39+
keywords.update(po2_scale=True)
40+
quantizer.calibration = functools.partial(calibration_cls, **keywords)
41+
42+
3143
class AqtConvGeneralTest(parameterized.TestCase):
3244

3345
@parameterized.parameters([
@@ -48,13 +60,17 @@ def test_conv_general_dilated(
4860
rhs_maxval=20.0,
4961
seed=0,
5062
):
51-
dg_raw_conv = aqt_conv.conv_general_dilated_make(2, lhs_bits, rhs_bits)
52-
63+
dg_raw_conv = aqt_conv.conv_general_dilated_make(
64+
2, lhs_bits, rhs_bits, initialize_calibration=False
65+
)
66+
# Power-of-2 scales allow FQ and AQT to be exactly the same.
67+
dg_quantizer = dg_raw_conv.dg_quantizer
5368
if dg_raw_conv.lhs:
54-
# Power-of-2 scales allow FQ and AQT to be exactly the same.
55-
dg_raw_conv.dg_quantizer.lhs.po2_scale = True
69+
_apply_po2_scale(dg_quantizer.lhs)
70+
dg_quantizer.lhs.init_calibration()
5671
if dg_raw_conv.rhs:
57-
dg_raw_conv.dg_quantizer.rhs.po2_scale = True
72+
_apply_po2_scale(dg_quantizer.rhs)
73+
dg_quantizer.rhs.init_calibration()
5874

5975
batch_n = 10
6076
contr_n = 20
@@ -94,12 +110,17 @@ def test_conv_general_dilated_quantized(
94110
seed=0,
95111
):
96112
"""Check that passing quantized lhs/rhs to aqt_conv_fn works."""
97-
dg_raw_conv = aqt_conv.conv_general_dilated_make(2, lhs_bits, rhs_bits)
113+
dg_raw_conv = aqt_conv.conv_general_dilated_make(
114+
2, lhs_bits, rhs_bits, initialize_calibration=False
115+
)
116+
# Power-of-2 scales allow FQ and AQT to be exactly the same.
117+
dg_quantizer = dg_raw_conv.dg_quantizer
98118
if dg_raw_conv.lhs:
99-
# Power-of-2 scales allow FQ and AQT to be exactly the same.
100-
dg_raw_conv.dg_quantizer.lhs.po2_scale = True
119+
_apply_po2_scale(dg_quantizer.lhs)
120+
dg_quantizer.lhs.init_calibration()
101121
if dg_raw_conv.rhs:
102-
dg_raw_conv.dg_quantizer.rhs.po2_scale = True
122+
_apply_po2_scale(dg_quantizer.rhs)
123+
dg_quantizer.rhs.init_calibration()
103124

104125
batch_n = 10
105126
contr_n = 20

0 commit comments

Comments (0)