@@ -207,6 +207,7 @@ def _modify_dg(
207
207
fwd_lhs_tricky_clip_and_round : bool = False ,
208
208
local_aqt : aqt .LocalAqt | None = None ,
209
209
clip_gradient : bool = False ,
210
+ use_asymmetric : bool = False ,
210
211
) -> aqt .DotGeneral :
211
212
dg = copy .deepcopy (readonly_dg )
212
213
if fwd_lhs_tricky_clip_and_round :
@@ -256,11 +257,15 @@ def _disable_quant_types(c, on_lhs=True, on_rhs=True):
256
257
# that the scales are not too large.
257
258
def disable_quant (c ):
258
259
_disable_quant_types (c )
259
- if isinstance (c .dg_quantizer .lhs .numerics , int_numerics .IntSymmetric ):
260
+ int_numerics_types = (
261
+ int_numerics .IntSymmetric ,
262
+ int_numerics .IntAsymmetric ,
263
+ )
264
+ if isinstance (c .dg_quantizer .lhs .numerics , int_numerics_types ):
260
265
c .dg_quantizer .lhs .numerics = (
261
266
c .dg_quantizer .lhs .numerics .replace (round = False )
262
267
)
263
- if isinstance (c .dg_quantizer .rhs .numerics , int_numerics . IntSymmetric ):
268
+ if isinstance (c .dg_quantizer .rhs .numerics , int_numerics_types ):
264
269
c .dg_quantizer .rhs .numerics = (
265
270
c .dg_quantizer .rhs .numerics .replace (round = False )
266
271
)
@@ -291,6 +296,11 @@ def disable_quant(c):
291
296
dg .fwd .dg_quantizer .rhs .numerics .replace (clip_gradient = clip_gradient )
292
297
)
293
298
299
+ if use_asymmetric :
300
+ # TODO(aqt): use native asymmetric quantization once it is supported.
301
+ # https://github.com/google/aqt/issues/725
302
+ config .set_asymmetric_quantization (dg , use_fake_quant = True )
303
+
294
304
return dg
295
305
296
306
@@ -307,6 +317,7 @@ def _aqt_dg_full_lr_diff(
307
317
readonly_dg : aqt .DotGeneral ,
308
318
dims : jax .lax .DotDimensionNumbers ,
309
319
clip_gradient : bool = False ,
320
+ use_asymmetric : bool = False ,
310
321
) -> Callable [[jnp .ndarray , jnp .ndarray ], jnp .ndarray ]:
311
322
dg = _modify_dg (
312
323
readonly_dg ,
@@ -319,6 +330,7 @@ def _aqt_dg_full_lr_diff(
319
330
fwd_lhs_tricky_clip_and_round = fwd_lhs_tricky_clip_and_round ,
320
331
local_aqt = local_aqt ,
321
332
clip_gradient = clip_gradient ,
333
+ use_asymmetric = use_asymmetric ,
322
334
)
323
335
dg = config .set_context (dg , key = jax .random .PRNGKey (4 ), train_step = None )
324
336
return lambda lhs , rhs : dg (lhs , rhs , dims )
@@ -335,6 +347,7 @@ def _aqt_dg_full(
335
347
readonly_dg : aqt .DotGeneral ,
336
348
dims : jax .lax .DotDimensionNumbers ,
337
349
clip_gradient : bool = False ,
350
+ use_asymmetric : bool = False ,
338
351
) -> Callable [[jnp .ndarray , jnp .ndarray ], jnp .ndarray ]:
339
352
return _aqt_dg_full_lr_diff (
340
353
lhs_dequant_mode = dequant_mode ,
@@ -348,6 +361,7 @@ def _aqt_dg_full(
348
361
readonly_dg = readonly_dg ,
349
362
dims = dims ,
350
363
clip_gradient = clip_gradient ,
364
+ use_asymmetric = use_asymmetric ,
351
365
)
352
366
353
367
@@ -359,13 +373,15 @@ def _aqt_dg_raw_lr_diff(
359
373
* ,
360
374
readonly_dg : aqt .DotGeneral ,
361
375
dims : jax .lax .DotDimensionNumbers ,
376
+ use_asymmetric : bool = False ,
362
377
) -> Callable [[jnp .ndarray , jnp .ndarray ], jnp .ndarray ]:
363
378
dg = _modify_dg (
364
379
readonly_dg ,
365
380
lhs_dequant_mode = lhs_dequant_mode ,
366
381
rhs_dequant_mode = rhs_dequant_mode ,
367
382
lhs_calibration_mode = lhs_calibration_mode ,
368
383
rhs_calibration_mode = rhs_calibration_mode ,
384
+ use_asymmetric = use_asymmetric ,
369
385
)
370
386
dg = config .set_context (dg , key = jax .random .PRNGKey (4 ), train_step = None )
371
387
dg .fwd .dg_quantizer .init_calibration ()
@@ -378,6 +394,7 @@ def _aqt_dg_raw(
378
394
* ,
379
395
readonly_dg : aqt .DotGeneral ,
380
396
dims : jax .lax .DotDimensionNumbers ,
397
+ use_asymmetric : bool = False ,
381
398
) -> Callable [[jnp .ndarray , jnp .ndarray ], jnp .ndarray ]:
382
399
return _aqt_dg_raw_lr_diff (
383
400
dequant_mode ,
@@ -386,6 +403,7 @@ def _aqt_dg_raw(
386
403
calibration_mode ,
387
404
readonly_dg = readonly_dg ,
388
405
dims = dims ,
406
+ use_asymmetric = use_asymmetric ,
389
407
)
390
408
391
409
@@ -557,6 +575,15 @@ def test_dot_general_calibration_with_contracting_axis(
557
575
dtype = jnp .float32 ,
558
576
clip_gradient = False ,
559
577
):
578
+ is_quantized = not all ([
579
+ isinstance (dg .fwd .dg_quantizer .lhs .numerics , no_numerics .NoNumerics ),
580
+ isinstance (dg .fwd .dg_quantizer .rhs .numerics , no_numerics .NoNumerics ),
581
+ isinstance (dg .dlhs .dg_quantizer .lhs .numerics , no_numerics .NoNumerics ),
582
+ isinstance (dg .dlhs .dg_quantizer .rhs .numerics , no_numerics .NoNumerics ),
583
+ isinstance (dg .drhs .dg_quantizer .lhs .numerics , no_numerics .NoNumerics ),
584
+ isinstance (dg .drhs .dg_quantizer .rhs .numerics , no_numerics .NoNumerics ),
585
+ ])
586
+
560
587
readonly_dg = dg
561
588
del dg
562
589
@@ -571,9 +598,24 @@ def test_dot_general_calibration_with_contracting_axis(
571
598
dims = dims ,
572
599
clip_gradient = clip_gradient ,
573
600
)
601
+ asym_dg_full = functools .partial (
602
+ _aqt_dg_full ,
603
+ readonly_dg = readonly_dg ,
604
+ dims = dims ,
605
+ clip_gradient = clip_gradient ,
606
+ # This should be removed once asymmetric quant supports use_fwd_quant.
607
+ use_fwd_quant = False ,
608
+ use_asymmetric = True ,
609
+ )
574
610
aqt_dg_raw = functools .partial (
575
611
_aqt_dg_raw , readonly_dg = readonly_dg , dims = dims
576
612
)
613
+ asym_dg_raw = functools .partial (
614
+ _aqt_dg_raw ,
615
+ readonly_dg = readonly_dg ,
616
+ dims = dims ,
617
+ use_asymmetric = True ,
618
+ )
577
619
modify_dg = functools .partial (_modify_dg , readonly_dg = readonly_dg )
578
620
check = functools .partial (_check_result_eq , lhs = lhs , rhs = rhs , gra = gra )
579
621
@@ -609,6 +651,20 @@ def test_dot_general_calibration_with_contracting_axis(
609
651
dict (test_gradient = False ),
610
652
),
611
653
])
654
+ check ([
655
+ ("default " , asym_dg_full (aqt .DequantMode .OUTPUT ), dict ()),
656
+ ("FQ " , asym_dg_full (aqt .DequantMode .THIS_INPUT ), dict ()),
657
+ (
658
+ "raw fwd " ,
659
+ asym_dg_raw (aqt .DequantMode .OUTPUT ),
660
+ dict (test_gradient = False ),
661
+ ),
662
+ (
663
+ "raw fwd FQ " ,
664
+ asym_dg_raw (aqt .DequantMode .THIS_INPUT ),
665
+ dict (test_gradient = False ),
666
+ ),
667
+ ])
612
668
613
669
check ([
614
670
(
@@ -631,6 +687,30 @@ def test_dot_general_calibration_with_contracting_axis(
631
687
),
632
688
])
633
689
690
+ if is_quantized :
691
+ # Asymmetric quantization does not currently support forward quantization.
692
+ with self .assertRaisesRegex (NotImplementedError , r"biases.*forward" ):
693
+ check ([
694
+ (
695
+ "fwd_quant=F" ,
696
+ aqt_dg_full (
697
+ aqt .DequantMode .OUTPUT ,
698
+ use_fwd_quant = False ,
699
+ use_asymmetric = True ,
700
+ ),
701
+ dict (),
702
+ ),
703
+ (
704
+ "fwd_quant=T" ,
705
+ aqt_dg_full (
706
+ aqt .DequantMode .OUTPUT ,
707
+ use_fwd_quant = True ,
708
+ use_asymmetric = True ,
709
+ ),
710
+ dict (),
711
+ ),
712
+ ])
713
+
634
714
check ([
635
715
(
636
716
"default " ,
@@ -641,14 +721,32 @@ def test_dot_general_calibration_with_contracting_axis(
641
721
dict (),
642
722
),
643
723
(
644
- "default " ,
724
+ "FQ " ,
645
725
aqt_dg_full (
646
726
aqt .DequantMode .THIS_INPUT ,
647
727
local_aqt = aqt .LocalAqt (contraction_axis_shard_count = 2 ),
648
728
),
649
729
dict (),
650
730
),
651
731
])
732
+ check ([
733
+ (
734
+ "default " ,
735
+ asym_dg_full (
736
+ aqt .DequantMode .OUTPUT ,
737
+ local_aqt = aqt .LocalAqt (contraction_axis_shard_count = 2 ),
738
+ ),
739
+ dict (),
740
+ ),
741
+ (
742
+ "FQ " ,
743
+ asym_dg_full (
744
+ aqt .DequantMode .THIS_INPUT ,
745
+ local_aqt = aqt .LocalAqt (contraction_axis_shard_count = 2 ),
746
+ ),
747
+ dict (),
748
+ ),
749
+ ])
652
750
653
751
if isinstance (
654
752
readonly_dg .fwd .dg_quantizer .lhs .numerics ,
0 commit comments