Skip to content

Commit b066596

Browse files
phoenix-meadowlark authored and copybara-github committed
Add support for asymmetric fake-quantization to AQTv2.
Integration of native quantization with biases will require computing the cross terms. See [#725](#725). Itemized changes: - Add `IntAsymmetric` to handle asymmetric integer numerics. - This class forgoes some of the more research-y parameters present on `IntSymmetric`. - Add `MinMaxCalibration` to calculate the scale and bias for asymmetric quantization. I additionally tested this change by training MNIST models using `flax_e2e_model`. With symmetric quantization the model fails to converge for `config.config_v4(fwd_bits=2, dlhs_bits=None, drhs_bits=None)` (due to `NaN` losses). With asymmetric quantization the model converges even with `config.config_v4(fwd_bits=2, dlhs_bits=2, drhs_bits=4)`. PiperOrigin-RevId: 651580879
1 parent b907430 commit b066596

File tree

5 files changed

+373
-50
lines changed

5 files changed

+373
-50
lines changed

aqt/jax/v2/aqt_dot_general_test.py

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def _modify_dg(
207207
fwd_lhs_tricky_clip_and_round: bool = False,
208208
local_aqt: aqt.LocalAqt | None = None,
209209
clip_gradient: bool = False,
210+
use_asymmetric: bool = False,
210211
) -> aqt.DotGeneral:
211212
dg = copy.deepcopy(readonly_dg)
212213
if fwd_lhs_tricky_clip_and_round:
@@ -256,11 +257,15 @@ def _disable_quant_types(c, on_lhs=True, on_rhs=True):
256257
# that the scales are not too large.
257258
def disable_quant(c):
258259
_disable_quant_types(c)
259-
if isinstance(c.dg_quantizer.lhs.numerics, int_numerics.IntSymmetric):
260+
int_numerics_types = (
261+
int_numerics.IntSymmetric,
262+
int_numerics.IntAsymmetric,
263+
)
264+
if isinstance(c.dg_quantizer.lhs.numerics, int_numerics_types):
260265
c.dg_quantizer.lhs.numerics = (
261266
c.dg_quantizer.lhs.numerics.replace(round=False)
262267
)
263-
if isinstance(c.dg_quantizer.rhs.numerics, int_numerics.IntSymmetric):
268+
if isinstance(c.dg_quantizer.rhs.numerics, int_numerics_types):
264269
c.dg_quantizer.rhs.numerics = (
265270
c.dg_quantizer.rhs.numerics.replace(round=False)
266271
)
@@ -291,6 +296,11 @@ def disable_quant(c):
291296
dg.fwd.dg_quantizer.rhs.numerics.replace(clip_gradient=clip_gradient)
292297
)
293298

299+
if use_asymmetric:
300+
# TODO(aqt): use native asymmetric quantization once it is supported.
301+
# https://github.com/google/aqt/issues/725
302+
config.set_asymmetric_quantization(dg, use_fake_quant=True)
303+
294304
return dg
295305

296306

@@ -307,6 +317,7 @@ def _aqt_dg_full_lr_diff(
307317
readonly_dg: aqt.DotGeneral,
308318
dims: jax.lax.DotDimensionNumbers,
309319
clip_gradient: bool = False,
320+
use_asymmetric: bool = False,
310321
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
311322
dg = _modify_dg(
312323
readonly_dg,
@@ -319,6 +330,7 @@ def _aqt_dg_full_lr_diff(
319330
fwd_lhs_tricky_clip_and_round=fwd_lhs_tricky_clip_and_round,
320331
local_aqt=local_aqt,
321332
clip_gradient=clip_gradient,
333+
use_asymmetric=use_asymmetric,
322334
)
323335
dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None)
324336
return lambda lhs, rhs: dg(lhs, rhs, dims)
@@ -335,6 +347,7 @@ def _aqt_dg_full(
335347
readonly_dg: aqt.DotGeneral,
336348
dims: jax.lax.DotDimensionNumbers,
337349
clip_gradient: bool = False,
350+
use_asymmetric: bool = False,
338351
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
339352
return _aqt_dg_full_lr_diff(
340353
lhs_dequant_mode=dequant_mode,
@@ -348,6 +361,7 @@ def _aqt_dg_full(
348361
readonly_dg=readonly_dg,
349362
dims=dims,
350363
clip_gradient=clip_gradient,
364+
use_asymmetric=use_asymmetric,
351365
)
352366

353367

@@ -359,13 +373,15 @@ def _aqt_dg_raw_lr_diff(
359373
*,
360374
readonly_dg: aqt.DotGeneral,
361375
dims: jax.lax.DotDimensionNumbers,
376+
use_asymmetric: bool = False,
362377
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
363378
dg = _modify_dg(
364379
readonly_dg,
365380
lhs_dequant_mode=lhs_dequant_mode,
366381
rhs_dequant_mode=rhs_dequant_mode,
367382
lhs_calibration_mode=lhs_calibration_mode,
368383
rhs_calibration_mode=rhs_calibration_mode,
384+
use_asymmetric=use_asymmetric,
369385
)
370386
dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None)
371387
dg.fwd.dg_quantizer.init_calibration()
@@ -378,6 +394,7 @@ def _aqt_dg_raw(
378394
*,
379395
readonly_dg: aqt.DotGeneral,
380396
dims: jax.lax.DotDimensionNumbers,
397+
use_asymmetric: bool = False,
381398
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
382399
return _aqt_dg_raw_lr_diff(
383400
dequant_mode,
@@ -386,6 +403,7 @@ def _aqt_dg_raw(
386403
calibration_mode,
387404
readonly_dg=readonly_dg,
388405
dims=dims,
406+
use_asymmetric=use_asymmetric,
389407
)
390408

391409

@@ -557,6 +575,15 @@ def test_dot_general_calibration_with_contracting_axis(
557575
dtype=jnp.float32,
558576
clip_gradient=False,
559577
):
578+
is_quantized = not all([
579+
isinstance(dg.fwd.dg_quantizer.lhs.numerics, no_numerics.NoNumerics),
580+
isinstance(dg.fwd.dg_quantizer.rhs.numerics, no_numerics.NoNumerics),
581+
isinstance(dg.dlhs.dg_quantizer.lhs.numerics, no_numerics.NoNumerics),
582+
isinstance(dg.dlhs.dg_quantizer.rhs.numerics, no_numerics.NoNumerics),
583+
isinstance(dg.drhs.dg_quantizer.lhs.numerics, no_numerics.NoNumerics),
584+
isinstance(dg.drhs.dg_quantizer.rhs.numerics, no_numerics.NoNumerics),
585+
])
586+
560587
readonly_dg = dg
561588
del dg
562589

@@ -571,9 +598,24 @@ def test_dot_general_calibration_with_contracting_axis(
571598
dims=dims,
572599
clip_gradient=clip_gradient,
573600
)
601+
asym_dg_full = functools.partial(
602+
_aqt_dg_full,
603+
readonly_dg=readonly_dg,
604+
dims=dims,
605+
clip_gradient=clip_gradient,
606+
# This should be removed once asymmetric quant supports use_fwd_quant.
607+
use_fwd_quant=False,
608+
use_asymmetric=True,
609+
)
574610
aqt_dg_raw = functools.partial(
575611
_aqt_dg_raw, readonly_dg=readonly_dg, dims=dims
576612
)
613+
asym_dg_raw = functools.partial(
614+
_aqt_dg_raw,
615+
readonly_dg=readonly_dg,
616+
dims=dims,
617+
use_asymmetric=True,
618+
)
577619
modify_dg = functools.partial(_modify_dg, readonly_dg=readonly_dg)
578620
check = functools.partial(_check_result_eq, lhs=lhs, rhs=rhs, gra=gra)
579621

@@ -609,6 +651,20 @@ def test_dot_general_calibration_with_contracting_axis(
609651
dict(test_gradient=False),
610652
),
611653
])
654+
check([
655+
("default ", asym_dg_full(aqt.DequantMode.OUTPUT), dict()),
656+
("FQ ", asym_dg_full(aqt.DequantMode.THIS_INPUT), dict()),
657+
(
658+
"raw fwd ",
659+
asym_dg_raw(aqt.DequantMode.OUTPUT),
660+
dict(test_gradient=False),
661+
),
662+
(
663+
"raw fwd FQ ",
664+
asym_dg_raw(aqt.DequantMode.THIS_INPUT),
665+
dict(test_gradient=False),
666+
),
667+
])
612668

613669
check([
614670
(
@@ -631,6 +687,30 @@ def test_dot_general_calibration_with_contracting_axis(
631687
),
632688
])
633689

690+
if is_quantized:
691+
# Asymmetric quantization does not currently support forward quantization.
692+
with self.assertRaisesRegex(NotImplementedError, r"biases.*forward"):
693+
check([
694+
(
695+
"fwd_quant=F",
696+
aqt_dg_full(
697+
aqt.DequantMode.OUTPUT,
698+
use_fwd_quant=False,
699+
use_asymmetric=True,
700+
),
701+
dict(),
702+
),
703+
(
704+
"fwd_quant=T",
705+
aqt_dg_full(
706+
aqt.DequantMode.OUTPUT,
707+
use_fwd_quant=True,
708+
use_asymmetric=True,
709+
),
710+
dict(),
711+
),
712+
])
713+
634714
check([
635715
(
636716
"default ",
@@ -641,14 +721,32 @@ def test_dot_general_calibration_with_contracting_axis(
641721
dict(),
642722
),
643723
(
644-
"default ",
724+
"FQ ",
645725
aqt_dg_full(
646726
aqt.DequantMode.THIS_INPUT,
647727
local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2),
648728
),
649729
dict(),
650730
),
651731
])
732+
check([
733+
(
734+
"default ",
735+
asym_dg_full(
736+
aqt.DequantMode.OUTPUT,
737+
local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2),
738+
),
739+
dict(),
740+
),
741+
(
742+
"FQ ",
743+
asym_dg_full(
744+
aqt.DequantMode.THIS_INPUT,
745+
local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2),
746+
),
747+
dict(),
748+
),
749+
])
652750

653751
if isinstance(
654752
readonly_dg.fwd.dg_quantizer.lhs.numerics,

aqt/jax/v2/calibration.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Union
1919
from aqt.jax.v2 import aqt_tensor
2020
from aqt.jax.v2 import utils
21+
from aqt.jax.v2.numerics import int_numerics
2122
from aqt.jax.v2.numerics import numerics
2223
import jax
2324
import jax.numpy as jnp
@@ -392,3 +393,49 @@ def _calculate_snr(
392393
snr = jnp.log(1 + signal / noise)
393394

394395
return snr
396+
397+
398+
@utils.flax_slots_kw_only_dataclass
class MinMaxCalibration(Calibration):
  """Calibration between the min and max values.

  Attributes:
    eps: Optional epsilon to add to the bound to avoid division by zero. Inf
      filtering is also performed by QTensor.quant() after division.
  """

  eps: float | None = None

  def get_scale_and_bias(
      self,
      x: jnp.ndarray,
      shared_axes: Sequence[utils.AxisIdx] | None,
      numerics_: int_numerics.IntAsymmetric,
      context: utils.Context | None = None,
  ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
    """Computes the scale and bias mapping `x` onto the quantization range.

    The scale is `(max(x) - min(x) [+ eps]) / numerics_.get_quant_bound()`
    and the bias is chosen so that `(min(x) + bias) / scale` equals the lower
    end of `numerics_.get_quant_range()`.

    Args:
      x: The tensor to calibrate on.
      shared_axes: Axes over which the min and max are reduced (with
        `keepdims=True`). Must not be None.
      numerics_: The numerics providing the quantization bound and range. Only
        `int_numerics.IntAsymmetric` is accepted.
      context: Unused.

    Returns:
      A `([scale], [bias])` pair of single-element lists, both cast to
      `self.dtype` when it is set and to `x.dtype` otherwise.

    Raises:
      NotImplementedError: If `numerics_` is not an `IntAsymmetric` instance.
    """
    del context
    msg = (
        'Perhaps you are using DequantMode.THIS_INPUT (fake_quant) and forgot'
        ' to set them.'
    )
    assert shared_axes is not None, msg
    if not isinstance(numerics_, int_numerics.IntAsymmetric):
      # Report the offending argument (`numerics_`), not the `numerics` module.
      raise NotImplementedError(
          'MinMaxCalibration only supports int_numerics.IntAsymmetric, but got '
          f'{numerics_}'
      )
    # NOTE(review): `self.dtype` is not declared in this class; presumably it
    # is inherited from `Calibration` -- confirm.
    dtype = self.dtype if self.dtype is not None else x.dtype

    # Scale the full width of x to the full width of the quantization range.
    x_min = jnp.min(x, axis=shared_axes, keepdims=True)
    x_max = jnp.max(x, axis=shared_axes, keepdims=True)
    bound = x_max - x_min
    if self.eps is not None:
      bound += self.eps
    scale = bound / numerics_.get_quant_bound()

    # Calculate bias s.t. quant(min(x)) = (min(x) + bias) / scale = quant_min.
    quant_min, _ = numerics_.get_quant_range()
    bias = quant_min * scale - x_min

    return [scale.astype(dtype)], [bias.astype(dtype)]

aqt/jax/v2/config.py

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,35 @@ def set_dg_raw_context(cfg_raw: DotGeneralRaw, key: Optional[jax.Array]):
8888
return ret_cfg
8989

9090

91-
def set_fwd_dequant_mode(
92-
cfg: DotGeneral,
91+
def set_dequant_mode(
92+
cfg: DotGeneralRaw,
9393
*,
9494
lhs_dequant_mode: Optional[DequantMode] = None,
9595
rhs_dequant_mode: Optional[DequantMode] = None,
9696
):
97+
"""Sets the dequant mode for the lhs and rhs of a single dot general raw."""
9798
if lhs_dequant_mode is not None:
98-
cfg.fwd.lhs.dequant_mode = lhs_dequant_mode
99+
cfg.lhs.dequant_mode = lhs_dequant_mode
99100
if rhs_dequant_mode is not None:
100-
cfg.fwd.rhs.dequant_mode = rhs_dequant_mode
101+
cfg.rhs.dequant_mode = rhs_dequant_mode
102+
103+
fake_quant = DequantMode.THIS_INPUT in [lhs_dequant_mode, rhs_dequant_mode]
104+
if fake_quant and jnp.issubdtype(cfg.dg_accumulator_dtype, jnp.integer):
105+
# Fake-quantization is not compatible with integer accumulation.
106+
cfg.dg_accumulator_dtype = None
107+
108+
109+
def set_fwd_dequant_mode(
    cfg: DotGeneral,
    *,
    lhs_dequant_mode: Optional[DequantMode] = None,
    rhs_dequant_mode: Optional[DequantMode] = None,
):
  """Sets the lhs/rhs dequant modes on the forward pass (`cfg.fwd`) only.

  Args:
    cfg: The dot general config whose `fwd` member is passed on.
    lhs_dequant_mode: Forwarded unchanged to `set_dequant_mode`.
    rhs_dequant_mode: Forwarded unchanged to `set_dequant_mode`.
  """
  fwd_cfg = cfg.fwd
  modes = dict(
      lhs_dequant_mode=lhs_dequant_mode,
      rhs_dequant_mode=rhs_dequant_mode,
  )
  set_dequant_mode(fwd_cfg, **modes)
101120

102121

103122
def set_fwd_calibration_mode(
@@ -404,6 +423,65 @@ def set_bits(
404423
return cfg
405424

406425

426+
def _get_asym_numerics(numerics_: numerics.AqtNumerics):
  """Returns the asymmetric counterpart of `numerics_`.

  Integer numerics (symmetric or already asymmetric) are rebuilt as a new
  `IntAsymmetric` carrying over the shared fields; `NoNumerics` is returned
  as-is.

  Args:
    numerics_: The numerics to convert.

  Returns:
    A new `int_numerics.IntAsymmetric`, or `numerics_` itself when it is a
    `no_numerics.NoNumerics`.

  Raises:
    NotImplementedError: If `numerics_` is neither integer numerics nor
      `NoNumerics`.
  """
  int_types = (int_numerics.IntSymmetric, int_numerics.IntAsymmetric)
  if isinstance(numerics_, int_types):
    # Fields copied verbatim from the source numerics onto the new instance.
    carried_over = (
        'bits',
        'clip',
        'clip_gradient',
        'round',
        'noise_fn',
        'dtype',
    )
    field_values = {name: getattr(numerics_, name) for name in carried_over}
    return int_numerics.IntAsymmetric(**field_values)

  if isinstance(numerics_, no_numerics.NoNumerics):
    return numerics_

  raise NotImplementedError(
      'Asymmetric quantization currently only supports integer numerics,'
      f' but got {numerics_}'
  )
448+
449+
450+
def _set_asymmetric_quantization(cfg: DotGeneralRaw, use_fake_quant: bool):
  """Switches one raw dot general from symmetric to asymmetric quantization.

  Replaces the numerics of both operands with their asymmetric equivalents and
  swaps each operand's calibration for `MinMaxCalibration`.

  Args:
    cfg: The raw dot general config, modified in place.
    use_fake_quant: If True, both operands are additionally set to
      `DequantMode.THIS_INPUT`.
  """
  asym_lhs = _get_asym_numerics(cfg.dg_quantizer.lhs.numerics)
  asym_rhs = _get_asym_numerics(cfg.dg_quantizer.rhs.numerics)
  set_numerics(cfg, asym_lhs, asym_rhs)

  for side in ('lhs', 'rhs'):
    quantizer = getattr(cfg.dg_quantizer, side)
    previous = quantizer.calibration
    if isinstance(previous, functools.partial):
      # Carry over the keyword arguments bound on the previous calibration.
      quantizer.calibration = functools.partial(
          calibration.MinMaxCalibration, **previous.keywords
      )
    else:
      quantizer.calibration = calibration.MinMaxCalibration

  # Only fake quantization currently supports quantization biases.
  if not use_fake_quant:
    return
  set_dequant_mode(
      cfg,
      lhs_dequant_mode=DequantMode.THIS_INPUT,
      rhs_dequant_mode=DequantMode.THIS_INPUT,
  )
476+
477+
478+
def set_asymmetric_quantization(cfg: DotGeneral, *, use_fake_quant: bool):
  """Switches every pass of `cfg` to asymmetric quantization.

  Args:
    cfg: The dot general config; its `fwd`, `dlhs` and `drhs` members are
      converted in that order.
    use_fake_quant: Forwarded to the per-pass conversion.
  """
  for pass_name in ('fwd', 'dlhs', 'drhs'):
    _set_asymmetric_quantization(getattr(cfg, pass_name), use_fake_quant)
483+
484+
407485
def set_scale_and_bias_dtype(cfg: DotGeneral, dtype: jnp.dtype):
408486
"""Set the dtype for all scales and biases in the given DotGeneral config."""
409487
assert isinstance(

0 commit comments

Comments
 (0)