Skip to content

Commit 163cc4c

Browse files
Cerebra Catalyst Teamcopybara-github
authored andcommitted
Internal prototype
PiperOrigin-RevId: 645409887
1 parent 917e7a4 commit 163cc4c

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

aqt/jax/v2/aqt_quantizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class Quantizer:
4343
_calibrator: AbstractAqtCalibration | None = utils.static_field(default=None)
4444
# Round up the calibration to power of 2 (po2).
4545
po2_scale: bool = utils.static_field()
46+
scale_dtype: jnp.dtype | None = utils.static_field(default=None)
4647
# TODO(yichizh): Factor out auxilliary dataclasses into a separate file.
4748
context: utils.Context
4849

@@ -85,6 +86,8 @@ def calibrate(self, x, *, calibration_axes) -> aqt_tensor.QTensor:
8586
bound = self._calibrator.get_bound(x, shared_axes, self.context)
8687
abs_max_mapped_to = self.numerics.abs_val_mapped_to()
8788
scale = bound / abs_max_mapped_to
89+
if self.scale_dtype:
90+
scale = scale.astype(self.scale_dtype)
8891

8992
if self.po2_scale:
9093
# With floor the biggest value (we are using jnp.max) is in the range of

aqt/jax/v2/config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,24 @@ def get_numerics(bits):
354354
return cfg
355355

356356

357+
def set_scale_dtype(cfg: DotGeneral, scale_dtype: jnp.dtype):
358+
"""Set scale_dtype for dot_general config."""
359+
assert isinstance(
360+
cfg.fwd.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer
361+
)
362+
assert isinstance(
363+
cfg.dlhs.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer
364+
)
365+
assert isinstance(
366+
cfg.drhs.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer
367+
)
368+
cfg.fwd.dg_quantizer.lhs.scale_dtype = scale_dtype
369+
cfg.fwd.dg_quantizer.rhs.scale_dtype = scale_dtype
370+
cfg.dlhs.dg_quantizer.lhs.scale_dtype = scale_dtype
371+
cfg.dlhs.dg_quantizer.rhs.scale_dtype = scale_dtype
372+
cfg.drhs.dg_quantizer.lhs.scale_dtype = scale_dtype
373+
cfg.drhs.dg_quantizer.rhs.scale_dtype = scale_dtype
374+
357375
################################################################################
358376
# Functions below are auxiliary config creators.
359377

0 commit comments

Comments
 (0)