def _cast_manual_promotion_dtype(x: jnp.ndarray) -> jnp.ndarray:
  """Casts x to a promotable float dtype if promote_dtype cannot handle it.

  int4 is cast to float32; float8_e4m3fn and float8_e5m2 are cast to bfloat16.
  Arrays of any other dtype are returned unchanged.
  """
  if x.dtype == jnp.int4:
    return jnp.float32(x)
  if x.dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2):
    return jnp.bfloat16(x)
  return x


def aqt_promote_dtype(
    lhs_in: jnp.ndarray, rhs_in: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
  """Promotes the dtype of lhs_in and rhs_in.

  int4, float8_e4m3fn and float8_e5m2 are hand-held here because
  promote_dtype fails for these dtypes:
    * If both inputs share the same one of these dtypes, they are returned
      untouched (no cast and no promotion).
    * Otherwise every input with one of these dtypes is first cast (int4 to
      float32, float8 variants to bfloat16) and the pair is then passed to
      nn.dtypes.promote_dtype.

  Args:
    lhs_in: Left-hand-side array.
    rhs_in: Right-hand-side array.

  Returns:
    A tuple of the promoted lhs_in and rhs_in.
  """
  manual_promotion_dtypes = (jnp.int4, jnp.float8_e4m3fn, jnp.float8_e5m2)

  # Matching low-precision operands are deliberately left as they are.
  if (
      lhs_in.dtype == rhs_in.dtype
      and lhs_in.dtype in manual_promotion_dtypes
  ):
    return lhs_in, rhs_in

  # After these casts neither operand has a dtype that promote_dtype rejects.
  lhs_in = _cast_manual_promotion_dtype(lhs_in)
  rhs_in = _cast_manual_promotion_dtype(rhs_in)
  lhs_in, rhs_in = nn.dtypes.promote_dtype(lhs_in, rhs_in)
  return lhs_in, rhs_in
- if lhs_in.dtype == jnp.int4: - lhs_in = jnp.float32(lhs_in) - if rhs_in.dtype == jnp.int4: - rhs_in = jnp.float32(rhs_in) - if lhs_in.dtype != jnp.int4 and rhs_in.dtype != jnp.int4: - lhs_in, rhs_in = nn.dtypes.promote_dtype(lhs_in, rhs_in) + lhs_in, rhs_in = aqt_promote_dtype(lhs_in, rhs_in) # yes_swap = whether einsum swaps [lhs,rhs] when passing them to dot_general einsum = functools.partial(aqt_dot_general.einsum, eqn=eqn)